From 7ea243f037ea637c876c1c2308eae5b2fd5706c4 Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Thu, 7 Oct 2021 18:37:53 +0000 Subject: [PATCH] test load model, perform normally --- .gitignore | 2 +- arguments.py | 235 ++--------- data_utils/configure_data.py | 3 +- model/__init__.py | 18 - model/base_model.py | 5 +- model/cuda2d_model.py | 16 +- move_weights.py | 105 +++++ mpu/__init__.py | 5 - mpu/layers.py | 2 - mpu/random.py | 386 ------------------ mpu/transformer.py | 12 +- pretrain_cogview2.py | 170 ++++++++ pretrain_gpt2.py | 16 +- scripts/ds_config_zero.json | 2 +- scripts/pretrain_cogview2.sh | 59 +++ scripts/pretrain_multiple_nodes.sh | 46 +-- .../cuda_2d_text2image.sh | 0 {scripts => scripts_old}/image2text.sh | 0 .../low_level_super_resolution.sh | 0 {scripts => scripts_old}/post_selection.sh | 0 scripts_old/pretrain_multiple_nodes.sh | 74 ++++ .../pretrain_single_node.sh | 0 {scripts => scripts_old}/super_resolution.sh | 0 {scripts => scripts_old}/testnan.sh | 0 {scripts => scripts_old}/text2image.sh | 0 tokenization/cogview/unified_tokenizer.py | 2 +- training/deepspeed_training.py | 23 +- training/learning_rates.py | 4 +- training/model_io.py | 4 +- utils.py | 311 -------------- 30 files changed, 516 insertions(+), 984 deletions(-) create mode 100644 move_weights.py delete mode 100755 mpu/random.py create mode 100755 pretrain_cogview2.py create mode 100755 scripts/pretrain_cogview2.sh rename {scripts => scripts_old}/cuda_2d_text2image.sh (100%) rename {scripts => scripts_old}/image2text.sh (100%) rename {scripts => scripts_old}/low_level_super_resolution.sh (100%) rename {scripts => scripts_old}/post_selection.sh (100%) create mode 100755 scripts_old/pretrain_multiple_nodes.sh rename {scripts => scripts_old}/pretrain_single_node.sh (100%) rename {scripts => scripts_old}/super_resolution.sh (100%) rename {scripts => scripts_old}/testnan.sh (100%) rename {scripts => scripts_old}/text2image.sh (100%) diff --git a/.gitignore b/.gitignore 
index c313862..45ed55c 100755 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ _cache* .vscode/ samples/ hostfile -pretrained/checkpoints +pretrained/ *.png *.jpg *.jpeg diff --git a/arguments.py b/arguments.py index 7ea0432..7344533 100755 --- a/arguments.py +++ b/arguments.py @@ -39,60 +39,20 @@ def add_model_config_args(parser): help='layer norm epsilon') group.add_argument('--hidden-dropout', type=float, default=0.1, help='dropout probability for hidden state transformer') - group.add_argument('--max-position-embeddings', type=int, default=512, + group.add_argument('--max-sequence-length', type=int, default=512, help='maximum number of position embeddings to use') - group.add_argument('--vocab-size', type=int, default=30522, + group.add_argument('--vocab-size', type=int, default=0, help='vocab size to use for non-character-level ' 'tokenization. This value will only be used when ' 'creating a tokenizer') - group.add_argument('--deep-init', action='store_true', - help='initialize bert model similar to gpt2 model.' - 'scales initialization of projection layers by a ' - 'factor of 1/sqrt(2N). Necessary to train bert ' - 'models larger than BERT-Large.') group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, help='Pad the vocab size to be divisible by this value.' 
'This is added for computational efficieny reasons.') - group.add_argument('--cpu-optimizer', action='store_true', - help='Run optimizer on CPU') - group.add_argument('--cpu_torch_adam', action='store_true', - help='Use Torch Adam as optimizer on CPU.') - - group.add_argument('--max-position-embeddings-finetune', type=int, default=-1, - help='maximum number of position embeddings to use in finetune') group.add_argument('--sandwich-ln', action='store_true', help='add sandwich ln in cogview.') return parser -def add_fp16_config_args(parser): - """Mixed precision arguments.""" - - group = parser.add_argument_group('fp16', 'fp16 configurations') - - group.add_argument('--fp16', action='store_true', - help='Run model in fp16 mode') - group.add_argument('--fp32-embedding', action='store_true', - help='embedding in fp32') - group.add_argument('--fp32-layernorm', action='store_true', - help='layer norm in fp32') - group.add_argument('--fp32-tokentypes', action='store_true', - help='embedding token types in fp32') - group.add_argument('--fp32-allreduce', action='store_true', - help='all-reduce in fp32') - group.add_argument('--hysteresis', type=int, default=2, - help='hysteresis for dynamic loss scaling') - group.add_argument('--loss-scale', type=float, default=None, - help='Static loss scaling, positive power of 2 ' - 'values can improve fp16 convergence. 
If None, dynamic' - 'loss scaling is used.') - group.add_argument('--loss-scale-window', type=float, default=1000, - help='Window over which to raise/lower dynamic scale') - group.add_argument('--min-scale', type=float, default=1, - help='Minimum loss scale for dynamic loss scale') - - return parser - def add_training_args(parser): """Training arguments.""" @@ -110,10 +70,6 @@ def add_training_args(parser): 'with larger models and sequences') group.add_argument('--checkpoint-num-layers', type=int, default=1, help='chunk size (number of layers) for checkpointing') - group.add_argument('--deepspeed-activation-checkpointing', action='store_true', - help='uses activation checkpointing from deepspeed') - group.add_argument('--clip-grad', type=float, default=1.0, - help='gradient clipping') group.add_argument('--train-iters', type=int, default=1000000, help='total number of iterations to train over all training runs') group.add_argument('--log-interval', type=int, default=50, @@ -123,16 +79,6 @@ def add_training_args(parser): group.add_argument('--summary-dir', type=str, default="", help="The directory to store the summary") group.add_argument('--seed', type=int, default=1234, help='random seed') - group.add_argument('--img-tokenizer-path', type=str, default=None, - help='The checkpoint file path of image tokenizer.') - group.add_argument('--img-tokenizer-num-tokens', type=int, default=None, - help='The num tokens of image tokenizer. ONLY use for pretraining with img-tokenizer UNKNOW.') - # Batch prodecuer arguments - group.add_argument('--reset-position-ids', action='store_true', - help='Reset posistion ids after end-of-document token.') - group.add_argument('--reset-attention-mask', action='store_true', - help='Reset self attention maske after ' - 'end-of-document token.') # Learning rate. 
group.add_argument('--lr-decay-iters', type=int, default=None, @@ -147,31 +93,32 @@ def add_training_args(parser): group.add_argument('--warmup', type=float, default=0.01, help='percentage of data to warmup on (.01 = 1% of all ' 'training iters). Default 0.01') - group.add_argument('--restart-iter', type=int, default=0, - help='restart with warmup from this iteration.') # model checkpointing group.add_argument('--save', type=str, default=None, help='Output directory to save checkpoints to.') + group.add_argument('--load', type=str, default=None, + help='Path to a directory containing a model checkpoint.') group.add_argument('--save-interval', type=int, default=5000, help='number of iterations between saves') - group.add_argument('--no-save-optim', action='store_true', - help='Do not save current optimizer.') + # group.add_argument('--no-save-optim', action='store_true', + # help='Do not save current optimizer.') + # group.add_argument('--no-load-optim', action='store_true', + # help='Do not load optimizer when loading checkpoint.') group.add_argument('--no-save-rng', action='store_true', help='Do not save current rng state.') - group.add_argument('--load', type=str, default=None, - help='Path to a directory containing a model checkpoint.') - group.add_argument('--no-load-optim', action='store_true', - help='Do not load optimizer when loading checkpoint.') group.add_argument('--no-load-rng', action='store_true', help='Do not load rng state when loading checkpoint.') - group.add_argument('--finetune', action='store_true', - help='Load model for finetuning. Do not load optimizer ' - 'or rng state from checkpoint and set iteration to 0. 
' - 'Assumed when loading a release checkpoint.') + group.add_argument('--mode', type=str, + default='pretrain', + choices=['pretrain', + 'finetune', + 'inference' + ], + help='what type of task to use, will influence auto-warmup, exp name, iteration') group.add_argument('--resume-dataloader', action='store_true', help='Resume the dataloader when resuming training. ' 'Does not apply to tfrecords dataloader, try resuming' - 'with a different seed in this case.') + 'with a different seed in this case.') # distributed training args group.add_argument('--distributed-backend', default='nccl', help='which backend to use for distributed ' @@ -180,11 +127,9 @@ def add_training_args(parser): group.add_argument('--local_rank', type=int, default=None, help='local rank passed from distributed launcher') - # loss scale - group.add_argument('--txt-loss-scale', type=float, default=1) - group.add_argument('--fast-load', action='store_true', - help='load checkpoints without locks.') - + group.add_argument('--fp16', action='store_true', + help='Run model in fp16 mode') + return parser @@ -243,9 +188,9 @@ def add_data_args(parser): group.add_argument('--model-parallel-size', type=int, default=1, help='size of the model parallel.') - group.add_argument('--shuffle', action='store_true', - help='Shuffle data. Shuffling is deterministic ' - 'based on seed and current epoch.') + # group.add_argument('--shuffle', action='store_true', + # help='Shuffle data. 
Shuffling is deterministic ' + # 'based on seed and current epoch.') group.add_argument('--train-data', nargs='+', default=None, help='Whitespace separated filenames or corpora names ' 'for training.') @@ -261,20 +206,6 @@ def add_data_args(parser): group.add_argument('--num-workers', type=int, default=2, help="""Number of workers to use for dataloading""") - group.add_argument('--dataset-type', type=str, - default='TokenizedDataset', - choices=['TokenizedDataset', - 'TextCodeDataset', - 'CompactBinaryDataset', - 'BinaryDataset' - ], - help='what type of dataset to use') - - group.add_argument('--max-memory-length', type=int, default=2048, - help="max memory buffer for attention") - group.add_argument('--new-dataset-path', type=str, default=None, - help='The folder we will dynamically check for lmdbs during training.') - return parser def add_generation_api_args(parser): @@ -290,102 +221,59 @@ def add_generation_api_args(parser): group.add_argument('--device', default=None) return parser - -def add_sparse_args(parser): + +def add_tokenization_args(parser): """sparse attention arguments.""" - group = parser.add_argument_group('Sparse Attention', 'sparse configurations') - group.add_argument('--sparse-type', type=str, default='standard', - choices=['standard', 'torch_1d', 'cuda_2d'], - help='whether use sparse attention.') # TODO: Temporally not using is-sparse==2 (not optimized), use 0 for inference. 
- # for torch_1d - group.add_argument("--query-window", type=int, default=128) - group.add_argument("--key-window-times", type=int, default=6) - group.add_argument("--num-pivot", type=int, default=768) - # for cuda_2d - group.add_argument("--kernel-size", type=int, default=9) - group.add_argument("--kernel-size2", type=int, default=7) - group.add_argument("--layout", type=str, default='64,1088,5184') + group = parser.add_argument_group('Tokenization', 'tokenization configurations') + group.add_argument('--tokenizer-type', type=str, default='fake', help='type name of tokenizer') + + group.add_argument('--img-tokenizer-path', type=str, default=None, + help='The checkpoint file path of image tokenizer.') return parser -def make_sparse_config(args): - args.layout = [int(x) for x in args.layout.split(',')] - sparse_config = argparse.Namespace(sparse_type=args.sparse_type) - sparse_config.layout = args.layout - if args.sparse_type == 'standard': - pass - if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation': - sparse_config.kernel_size = args.kernel_size - sparse_config.kernel_size2 = args.kernel_size2 - elif args.sparse_type == 'torch_1d': - raise NotImplementedError - args.sparse_config = sparse_config - -def get_args(): + + +def get_args(args_list=None): """Parse all the args.""" - parser = argparse.ArgumentParser(description='PyTorch CogView Model') + parser = argparse.ArgumentParser(description='Swiss Army Transformer') parser = add_model_config_args(parser) - parser = add_fp16_config_args(parser) parser = add_training_args(parser) parser = add_evaluation_args(parser) - parser = add_text_generate_args(parser) parser = add_data_args(parser) + parser = add_tokenization_args(parser) + parser = add_text_generate_args(parser) parser = add_generation_api_args(parser) - parser = add_sparse_args(parser) # Include DeepSpeed configuration arguments parser = deepspeed.add_config_arguments(parser) - args = parser.parse_args() - 
make_sparse_config(args) + args = parser.parse_args(args_list) if not args.train_data: print('WARNING: No training data specified') - elif args.sparse_type == 'torch_1d' and (args.max_position_embeddings - 1) % args.query_window != 0: - raise ValueError('During sparse training, the sequence length must be exactly divided by window_size.') args.cuda = torch.cuda.is_available() args.rank = int(os.getenv('RANK', '0')) args.world_size = int(os.getenv("WORLD_SIZE", '1')) - if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: - mpi_define_env(args) - elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): - # We are using (OpenMPI) mpirun for launching distributed data parallel processes - local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) - local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) - - # Possibly running with Slurm - num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) - nodeid = int(os.getenv('SLURM_NODEID', '0')) - - args.local_rank = local_rank - args.rank = nodeid * local_size + local_rank - args.world_size = num_nodes * local_size + args.model_parallel_size = min(args.model_parallel_size, args.world_size) if args.rank == 0: print('using world size: {} and model-parallel size: {} '.format( args.world_size, args.model_parallel_size)) - args.dynamic_loss_scale = False - if args.loss_scale is None: - args.dynamic_loss_scale = True - if args.rank == 0: - print(' > using dynamic loss scaling') - - # The args fp32_* or fp16_* meant to be active when the - # args fp16 is set. So the default behaviour should all - # be false. 
- if not args.fp16: - args.fp32_embedding = False - args.fp32_tokentypes = False - args.fp32_layernorm = False - if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None: with open(args.deepspeed_config) as file: deepspeed_config = json.load(file) + if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: + args.fp16 = True + else: + args.fp16 = False + if args.checkpoint_activations: + args.deepspeed_activation_checkpointing = True if "train_micro_batch_size_per_gpu" in deepspeed_config: args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] if "gradient_accumulation_steps" in deepspeed_config: @@ -399,40 +287,3 @@ def get_args(): return args -def mpi_define_env(args): - ''' For training CogView via MPI to setup the connection. - Omit this function if use the basic deepspeed pdsh runner. - ''' - from mpi4py import MPI - import subprocess - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() - - master_addr = None - if rank == 0: - hostname_cmd = ["hostname -I"] - result = subprocess.check_output(hostname_cmd, shell=True) - master_addr = result.decode('utf-8').split()[0] - master_addr = comm.bcast(master_addr, root=0) - - # Determine local rank by assuming hostnames are unique - proc_name = MPI.Get_processor_name() - all_procs = comm.allgather(proc_name) - local_rank = sum([i == proc_name for i in all_procs[:rank]]) - - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - args.local_rank = local_rank - args.world_size = world_size - args.rank = rank - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = "29500" # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 - - print( - "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" - .format(os.environ['RANK'], - args.local_rank, - os.environ['WORLD_SIZE'], - os.environ['MASTER_ADDR'], - os.environ['MASTER_PORT'])) diff --git a/data_utils/configure_data.py 
b/data_utils/configure_data.py index 6556673..95c3bd1 100755 --- a/data_utils/configure_data.py +++ b/data_utils/configure_data.py @@ -52,7 +52,7 @@ def make_data_loader(dataset, batch_size, num_iters, args): return data_loader -def make_dataset_full(dataset_type, path, split, args, create_dataset_function, **kwargs): +def make_dataset_full(path, split, args, create_dataset_function, **kwargs): """function to create datasets+tokenizers for common options""" print('make dataset ...', path) if split is None: @@ -93,7 +93,6 @@ def make_loaders(args, create_dataset_function): data_set_args = { 'path': args.train_data, - 'dataset_type': args.dataset_type, 'split': split, } diff --git a/model/__init__.py b/model/__init__.py index beb1271..e69de29 100755 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,18 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization -from .gpt2_modeling import GPT2Model - diff --git a/model/base_model.py b/model/base_model.py index 7b7769c..2b74fb0 100644 --- a/model/base_model.py +++ b/model/base_model.py @@ -12,7 +12,6 @@ import sys import math import random import torch -from functools import partial from mpu import BaseTransformer @@ -28,7 +27,7 @@ class BaseModel(torch.nn.Module): vocab_size=args.vocab_size, hidden_size=args.hidden_size, num_attention_heads=args.num_attention_heads, - max_sequence_length=args.max_position_embeddings, + max_sequence_length=args.max_sequence_length, embedding_dropout_prob=args.hidden_dropout, attention_dropout_prob=args.attention_dropout, output_dropout_prob=args.hidden_dropout, @@ -58,7 +57,7 @@ class BaseModel(torch.nn.Module): hooks = {} for name in names: if hasattr(self, name): - hooks[name] = partial(getattr(self, name), self) + hooks[name] = getattr(self, name) return hooks def disable_untrainable_params(self): diff --git a/model/cuda2d_model.py b/model/cuda2d_model.py index a9e1175..661e49f 100644 --- a/model/cuda2d_model.py +++ b/model/cuda2d_model.py @@ -20,8 +20,8 @@ from .mixins import PositionEmbeddingMixin, AttentionMixin from mpu.transformer import split_tensor_along_last_dim from mpu.local_attention_function import f_similar, f_weighting -from mpu.random import get_cuda_rng_tracker from mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker class Cuda2dModel(BaseModel): @@ -88,7 +88,19 @@ class Cuda2dModel(BaseModel): output_1 = dense_plus(context_layer1) output = torch.cat((output_0, output_1), dim=1) - return output + return output, None + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, 
default=9) + group.add_argument("--kernel-size2", type=int, default=7) + group.add_argument("--layout", type=str, default='64,1088,5184') + group.add_argument("--new-sequence-length", type=int, default=5185) + return parser def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, **kwargs): ''' diff --git a/move_weights.py b/move_weights.py new file mode 100644 index 0000000..7d41f2b --- /dev/null +++ b/move_weights.py @@ -0,0 +1,105 @@ +# %% +import torch +old = torch.load('pretrained/cogview/cogview-base/142000/mp_rank_00_model_states.pt', map_location='cpu') + +old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight'] +del old['module']['word_embeddings.weight'] + +from model.base_model import BaseModel +import argparse +import os +args = argparse.Namespace( + num_layers=48, + vocab_size=58240, + hidden_size=2560, + num_attention_heads=40, + max_sequence_length=1089, + hidden_dropout=0.1, + attention_dropout=0.1, + checkpoint_activations=True, + checkpoint_num_layers=1, + sandwich_ln=True, + model_parallel_size=1, + world_size=1, + rank=0 + ) +init_method = 'tcp://' +master_ip = os.getenv('MASTER_ADDR', 'localhost') +master_port = os.getenv('MASTER_PORT', '6000') +init_method += master_ip + ':' + master_port +torch.distributed.init_process_group( + backend='nccl', + world_size=args.world_size, rank=args.rank,init_method=init_method) +import mpu + # Set the model-parallel / data-parallel communicators. 
+mpu.initialize_model_parallel(args.model_parallel_size) +print('bg') +model = BaseModel(args) +missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False) +torch.save(old, 'pretrained/cogview/cogview-base/142000/mp_rank_00_model_states.pt') +# %% +import torch +old = torch.load('pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt', map_location='cpu') +old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight'] +del old['module']['word_embeddings.weight'] +#%% +from model.cuda2d_model import Cuda2dModel +import argparse +import os +args = argparse.Namespace( + num_layers=48, + vocab_size=58240, + hidden_size=2560, + num_attention_heads=40, + max_sequence_length=1089, + hidden_dropout=0.1, + attention_dropout=0.1, + checkpoint_activations=True, + checkpoint_num_layers=1, + sandwich_ln=True, + model_parallel_size=1, + world_size=1, + rank=0, + new_sequence_length=1089+4096, + layout='0,64,1088,5184', + kernel_size=9, + kernel_size2=7 + ) +# %% +init_method = 'tcp://' +master_ip = os.getenv('MASTER_ADDR', 'localhost') +master_port = os.getenv('MASTER_PORT', '6000') +init_method += master_ip + ':' + master_port +torch.distributed.init_process_group( + backend='nccl', + world_size=args.world_size, rank=args.rank,init_method=init_method) +import mpu + # Set the model-parallel / data-parallel communicators. 
+mpu.initialize_model_parallel(args.model_parallel_size) +print('bg') +#%% +model = Cuda2dModel(args) + +#%% +old['module']['mixins.0.position_embeddings.weight'] = old['module']['transformer.position_embeddings_plus.weight'] +del old['module']['transformer.position_embeddings_plus.weight'] + +for i in range(48): + old['module'][f'mixins.1.query_key_value.{i}.weight'] = \ + old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight'] + del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight'] + old['module'][f'mixins.1.query_key_value.{i}.bias'] = \ + old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias'] + del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias'] + old['module'][f'mixins.1.dense.{i}.weight'] = \ + old['module'][f'transformer.layers.{i}.attention.dense_plus.weight'] + del old['module'][f'transformer.layers.{i}.attention.dense_plus.weight'] + old['module'][f'mixins.1.dense.{i}.bias'] = \ + old['module'][f'transformer.layers.{i}.attention.dense_plus.bias'] + del old['module'][f'transformer.layers.{i}.attention.dense_plus.bias'] +# %% +missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False) + +# %% +torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt') +# %% diff --git a/mpu/__init__.py b/mpu/__init__.py index 358f997..c2ebd2f 100755 --- a/mpu/__init__.py +++ b/mpu/__init__.py @@ -40,10 +40,5 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region -from .random import checkpoint -from .random import partition_activations_in_checkpoint -from .random import get_cuda_rng_tracker -from .random import model_parallel_cuda_manual_seed - from .transformer import BaseTransformer from .transformer import LayerNorm diff --git a/mpu/layers.py b/mpu/layers.py index 2739bd9..72ec990 100755 --- a/mpu/layers.py 
+++ b/mpu/layers.py @@ -33,9 +33,7 @@ from .mappings import copy_to_model_parallel_region from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region -from .random import get_cuda_rng_tracker from .utils import divide -from .utils import split_tensor_along_last_dim from .utils import VocabUtility diff --git a/mpu/random.py b/mpu/random.py deleted file mode 100755 index 69f4cec..0000000 --- a/mpu/random.py +++ /dev/null @@ -1,386 +0,0 @@ -# coding=utf-8 -#Modified by Samyam Rajbhandari -#Used to partition the activations stored for backward propagation -#Therefore reduces the memory consumption - -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -import contextlib -import torch.distributed as dist -import torch -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager -#from torch.utils.checkpoint import detach_variable - - -import torch.distributed as dist -PARTITION_ACTIVATIONS = False -PA_CORRECTNESS_TEST= False - -def see_memory_usage(message, force=False): - if not force: - return - dist.barrier() - if dist.get_rank() == 0: - print(message) - print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes") - print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes") - print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes") - print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes") - print(" ") - #input("Press Any Key To Continue ..") - - -from .initialize import get_data_parallel_rank -from .initialize import get_model_parallel_rank -from .initialize import get_model_parallel_world_size -from .initialize import get_model_parallel_group - -mp_rank = None #get_model_parallel_rank() -mp_size = None #get_model_parallel_world_size() -mp_group = None #get_model_parallel_group() - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' -transport_stream = None -cuda_device=None -def detach_variable(inputs, device=None): - if isinstance(inputs, tuple): - out = [] - for inp in inputs: - if not isinstance(inp, torch.Tensor): - out.append(inp) - continue - - requires_grad = inp.requires_grad - - if device is not None: - x = inp.to(device=device) - else: - x = inp - - x = x.detach() - x.requires_grad = requires_grad - out.append(x) - return tuple(out) - else: - raise RuntimeError( - "Only tuple of tensors is supported. 
Got Unsupported input type: ", type(inputs).__name__) - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Argumentss: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - else: - # newer PyTorch - if device == -1: - device = torch.device('cuda') - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device('cuda', device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. 
For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. 
- Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model paralle groups. This is used for - example for dropout in the non-model-parallel regions. - model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - model_parallel_seed = offset + get_model_parallel_rank() - # Data parallel gets the original sedd. - data_parallel_seed = seed - - if torch.distributed.get_rank() == 0: - print('> initializing model parallel cuda seeds on global rank {}, ' - 'model parallel rank {}, and data parallel rank {} with ' - 'model parallel seed: {} and data parallel seed: {}'.format( - torch.distributed.get_rank(), get_model_parallel_rank(), - get_data_parallel_rank(), model_parallel_seed, - data_parallel_seed), flush=True) - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. 
- _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, - model_parallel_seed) - - -def get_partition_start(item): - global mp_rank, mp_size, mp_group - partition_size = get_partition_size(item) - start = partition_size * mp_rank - return int(start) - -def get_partition_size(item): - global mp_rank, mp_size, mp_group - size = item.numel() - partition_size = size/mp_size - return int(partition_size) - -def get_full_inputs(tensors): - inputs=[] - for i in range(int(len(tensors)/2)-1): - item = tensors[2 * i] - size = tensors[2* i + 1] - partition_size = item.numel() - tensor_size = partition_size * mp_size - flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device) - partitions=[] - for i in range(mp_size): - part_i = flat_tensor.narrow(0, partition_size * i , partition_size) - if i == mp_rank: - part_i.copy_(item) - partitions.append(part_i) - dist.all_gather(partitions,partitions[mp_rank], group=mp_group) - input_tensor = flat_tensor.view(list(size.numpy())) - item.data=input_tensor.data - - inputs.append(item) - inputs.append(tensors[-2]) - - return tuple(inputs) - - - -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. 
- """ - @staticmethod - def forward(ctx, run_function, *args): - ctx.run_function = run_function - global mp_rank, mp_size, mp_group - if mp_rank is None: - mp_rank = get_model_parallel_rank() - mp_size = get_model_parallel_world_size() - mp_group = get_model_parallel_group() - - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS - if cuda_device is None: - if dist.get_rank() == 0: - print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}") - - cuda_device = torch.cuda.current_device() - #The transport stream is used to overlap the allgather communication for the activations - #with the computation in the backward pass - transport_stream = torch.cuda.Stream(device=cuda_device) - - if PARTITION_ACTIVATIONS: - inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]] - inputs.append(args[-1]) - - #just in case something funky is happening such as reuse of inputs - inputs_cuda = [item.to(cuda_device) for item in args] - - # Copy the rng states. 
- ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - #ctx.save_for_backward(*args) - with torch.no_grad(): - outputs = run_function(*inputs_cuda) - - del inputs_cuda - - if PARTITION_ACTIVATIONS: - new_args = [] - for arg, inp in zip(args,inputs): - size= torch.tensor(arg.size()) - arg.data = inp.data - new_args.append(arg) - new_args.append(size) - ctx.save_for_backward(*new_args) - else: - ctx.save_for_backward(*args) - - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad(), " - "please use .backward() if possible") - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS - - if PARTITION_ACTIVATIONS: - with torch.cuda.stream(transport_stream): - inputs = get_full_inputs(ctx.saved_tensors) - detached_inputs = detach_variable(inputs) - else: - inputs = ctx.saved_tensors - detached_inputs = detach_variable(inputs) - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - if PARTITION_ACTIVATIONS: - current_stream=torch.cuda.current_stream() - current_stream.wait_stream(transport_stream) - - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - # Set the states back to what it was at the start of this function. 
- torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - torch.autograd.backward(outputs, args) - return (None,) + tuple(inp.grad for inp in detached_inputs) - - -def checkpoint(function, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint.""" - return CheckpointFunction.apply(function, *args) - -def partition_activations_in_checkpoint(partition_activation): - global PARTITION_ACTIVATIONS - PARTITION_ACTIVATIONS=partition_activation - if dist.get_rank() == 0: - print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************") - - diff --git a/mpu/transformer.py b/mpu/transformer.py index 9eab060..099dcb5 100755 --- a/mpu/transformer.py +++ b/mpu/transformer.py @@ -27,9 +27,7 @@ from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedd from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region import deepspeed - -from .random import checkpoint -from .random import get_cuda_rng_tracker +from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu from .utils import split_tensor_along_last_dim @@ -250,11 +248,11 @@ class BaseTransformerLayer(torch.nn.Module): # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. - mlp_output = self.mlp(layernorm_output) + mlp_output = self.mlp(layernorm_output, *other_tensors) # Fourth LayerNorm if self.sandwich_ln: - mlp_output = self.fourth_layernorm(mlp_output, *other_tensors) + mlp_output = self.fourth_layernorm(mlp_output) # Second residual connection. 
output = layernorm_input + mlp_output @@ -280,10 +278,6 @@ class BaseTransformer(torch.nn.Module): hooks={} ): super(BaseTransformer, self).__init__() - if deepspeed.checkpointing.is_configured(): - global get_cuda_rng_tracker, checkpoint - get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker - checkpoint = deepspeed.checkpointing.checkpoint # recording parameters self.parallel_output = parallel_output diff --git a/pretrain_cogview2.py b/pretrain_cogview2.py new file mode 100755 index 0000000..032fade --- /dev/null +++ b/pretrain_cogview2.py @@ -0,0 +1,170 @@ +# -*- encoding: utf-8 -*- +''' +@File : pretrain_cogview2.py +@Time : 2021/10/06 00:58:32 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import argparse +import numpy as np + +import mpu +from arguments import get_args +from model.cuda2d_model import Cuda2dModel +from training.deepspeed_training import training_main +from data_utils import BinaryDataset +from tokenization import get_tokenizer +from tokenization.cogview import TextCodeTemplate + +def get_masks_and_position_ids(data, + loss_mask=None, + attention_mask=None, args=None): + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if attention_mask is None: + assert loss_mask is not None + # loss_mask has n_pad(+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse. + attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril() + for i in range(batch_size): + attention_mask[i].fill_diagonal_(1) + attention_mask = attention_mask.unsqueeze(1) + + # Loss mask. + if loss_mask is None: + loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device) + + # Position ids. 
+ assert loss_mask is not None + layout = args.layout + assert seq_length == layout[-1] + n_pads = seq_length - loss_mask.sum(dim=-1).long() + position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long, + device=data.device) + for i in range(batch_size): + torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]], + dtype=torch.long, device=data.device) + torch.arange(layout[2] - layout[1], + out=position_ids[i, layout[1]:], + dtype=torch.long, device=data.device) + + return attention_mask, loss_mask, position_ids + + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['text', 'loss_mask'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens_ = data_b['text'].long() + loss_mask = data_b['loss_mask'].float() + + labels = tokens_[:, 1:].contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + attention_mask = None + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_masks_and_position_ids( + tokens, + loss_mask=loss_mask, + attention_mask=attention_mask, + args=args + ) + # Convert + if args.fp16: + attention_mask = attention_mask.half() + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. 
+ timers('batch generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + # split img & txt positions, [PAD] not included # TODO check enough + tokenizer = get_tokenizer() + img_txt_sep = tokenizer.img_tokenizer.num_tokens + img_indices_bool = (tokens < img_txt_sep) & (loss_mask > 0) + txt_indices_bool = (~img_indices_bool) & (loss_mask > 0) + # Forward model. + logits, *mems = model(tokens, position_ids, attention_mask) + losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels) + # scaling loss mask + loss_mask[txt_indices_bool] *= args.txt_loss_scale + loss_mask = loss_mask.view(-1) + + losses = losses.view(-1) * loss_mask + loss = torch.sum(losses) / loss_mask.sum() + # ===================== Log partial losses ======================== # + img_indices_bool2 = img_indices_bool.clone() + img_indices_bool2[:, :args.layout[1]] = False + img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1) + + img_indices_bool = img_indices_bool.view(-1) + txt_indices_bool = txt_indices_bool.view(-1) + img_loss = losses[img_indices_bool].detach().sum() / max(img_indices_bool.sum(), 1) + txt_loss = losses[txt_indices_bool].detach().sum() / max(txt_indices_bool.sum(), 1) / args.txt_loss_scale + # ===================== END OF BLOCK ======================= # + return loss, {'img_loss': img_loss, 'txt_loss': txt_loss, 'img_loss2': img_loss2} + +def create_dataset_function(path, args): + tokenizer = get_tokenizer() + layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME + def process_fn(row): + row = row.astype(np.int64) + codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))] + + text = row[:layout[0]] + text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1] + n_pad = layout[0]-3-len(text) + parts = [ + np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64), + TextCodeTemplate(text, codes[1], 
tokenizer), + *codes[2:] + ] + ret = np.concatenate(parts, axis=0) + return {'text': ret, + 'loss_mask': np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS] + } + return BinaryDataset(path, process_fn, length_per_sample=layout[-1]) + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + + py_parser.add_argument('--txt-loss-scale', type=float, default=1) + + Cuda2dModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.layout = [int(x) for x in args.layout.split(',')] + + training_main(args, model_cls=Cuda2dModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py index a8a0ea2..d5df575 100755 --- a/pretrain_gpt2.py +++ b/pretrain_gpt2.py @@ -12,12 +12,13 @@ import sys import math import random import torch +import argparse import numpy as np import mpu from arguments import get_args from model.base_model import BaseModel -from training.deepspeed_training import main +from training.deepspeed_training import training_main from data_utils import BinaryDataset from tokenization import get_tokenizer from tokenization.cogview import TextCodeTemplate @@ -32,6 +33,7 @@ def get_masks_and_position_ids(data, if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device) attention_mask.tril_() + attention_mask.unsqueeze_(1) # Loss mask. if loss_mask is None: @@ -91,7 +93,6 @@ def forward_step(data_iterator, model, args, timers): tokens, labels, loss_mask, attention_mask, position_ids = get_batch( data_iterator, args, timers) timers('batch generator').stop() - # Forward model. 
logits, *mems = model(tokens, position_ids, attention_mask) losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels) @@ -124,6 +125,11 @@ def create_dataset_function(path, args): } return BinaryDataset(path, process_fn, length_per_sample=layout[-1]) -if __name__ == '__main__': - args = get_args() - main(args, model_cls=BaseModel, forward_step=forward_step, create_dataset_function=create_dataset_function) +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--new_hyperparam', type=str, default=None) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + training_main(args, model_cls=BaseModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json index b985583..f43259b 100755 --- a/scripts/ds_config_zero.json +++ b/scripts/ds_config_zero.json @@ -1,5 +1,5 @@ { - "train_micro_batch_size_per_gpu": 2, + "train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1, "steps_per_print": 1, "gradient_clipping": 0.1, diff --git a/scripts/pretrain_cogview2.sh b/scripts/pretrain_cogview2.sh new file mode 100755 index 0000000..3bcb93d --- /dev/null +++ b/scripts/pretrain_cogview2.sh @@ -0,0 +1,59 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +HOST_FILE_PATH="hostfile_single" + +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name pretrain-cogview2-test \ + --tokenizer-type cogview \ + --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ + --model-parallel-size ${MP_SIZE} \ + --mode pretrain \ + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 40 \ + --train-iters 200000 \ + --resume-dataloader \ + --train-data ${full_data} \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .1 \ + --checkpoint-activations \ + --max-sequence-length 1089 \ + --sandwich-ln \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 1000 \ + --save $main_dir/checkpoints \ + --load pretrained/cogview/cogview2-base +" + # --load pretrained/cogview/cogview-base + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh index f099f80..5a09f77 100755 --- a/scripts/pretrain_multiple_nodes.sh +++ b/scripts/pretrain_multiple_nodes.sh @@ -2,7 +2,7 @@ # Change for multinode config -NUM_WORKERS=19 +NUM_WORKERS=1 NUM_GPUS_PER_WORKER=8 MP_SIZE=1 @@ -10,55 +10,39 @@ script_path=$(realpath $0) script_dir=$(dirname $script_path) 
main_dir=$(dirname $script_dir) -# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" HOST_FILE_PATH="hostfile" -# OPTIONS_NCCL="" -# HOST_FILE_PATH="hostfile_single" +HOST_FILE_PATH="hostfile_single" -small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" config_json="$script_dir/ds_config_zero.json" gpt_options=" \ - --experiment-name cogview-base-long \ - --img-tokenizer-num-tokens 8192 \ - --dataset-type CompactBinaryDataset \ + --experiment-name pretrain-gpt2-cogview-test \ + --tokenizer-type cogview \ + --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ --model-parallel-size ${MP_SIZE} \ - --num-layers 48 \ - --hidden-size 2560 \ - --num-attention-heads 40 \ - --train-iters 300000 \ + --mode pretrain \ + --num-layers 12 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --train-iters 200000 \ --resume-dataloader \ - --train-data ${full_data} \ + --train-data ${small_data} \ --split 949,50,1 \ --distributed-backend nccl \ --lr-decay-style cosine \ --warmup .1 \ --checkpoint-activations \ - --deepspeed-activation-checkpointing \ - --max-position-embeddings 1089 \ - --max-memory-length 0 \ + --max-sequence-length 1089 \ --sandwich-ln \ - --txt-loss-scale 0.1 \ - --sparse-type cuda_2d \ --fp16 \ --save-interval 2000 \ - --no-load-optim \ - --no-save-optim \ --eval-interval 1000 \ - --save $main_dir/data/checkpoints \ - --fast-load \ - --load data/checkpoints/cogview-base \ - --finetune + --save $main_dir/checkpoints \ " - -# --finetune - # --save $main_dir/data/checkpoints \ - # --restart-iter 199000 - - - + # --load pretrained/cogview/cogview-base 
gpt_options="${gpt_options} diff --git a/scripts/cuda_2d_text2image.sh b/scripts_old/cuda_2d_text2image.sh similarity index 100% rename from scripts/cuda_2d_text2image.sh rename to scripts_old/cuda_2d_text2image.sh diff --git a/scripts/image2text.sh b/scripts_old/image2text.sh similarity index 100% rename from scripts/image2text.sh rename to scripts_old/image2text.sh diff --git a/scripts/low_level_super_resolution.sh b/scripts_old/low_level_super_resolution.sh similarity index 100% rename from scripts/low_level_super_resolution.sh rename to scripts_old/low_level_super_resolution.sh diff --git a/scripts/post_selection.sh b/scripts_old/post_selection.sh similarity index 100% rename from scripts/post_selection.sh rename to scripts_old/post_selection.sh diff --git a/scripts_old/pretrain_multiple_nodes.sh b/scripts_old/pretrain_multiple_nodes.sh new file mode 100755 index 0000000..f099f80 --- /dev/null +++ b/scripts_old/pretrain_multiple_nodes.sh @@ -0,0 +1,74 @@ +#! /bin/bash + +# Change for multinode config + +NUM_WORKERS=19 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +# OPTIONS_NCCL="" +# HOST_FILE_PATH="hostfile_single" + +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name cogview-base-long \ + --img-tokenizer-num-tokens 8192 \ + --dataset-type CompactBinaryDataset \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 40 \ + --train-iters 300000 \ + --resume-dataloader \ + --train-data 
${full_data} \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .1 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --max-position-embeddings 1089 \ + --max-memory-length 0 \ + --sandwich-ln \ + --txt-loss-scale 0.1 \ + --sparse-type cuda_2d \ + --fp16 \ + --save-interval 2000 \ + --no-load-optim \ + --no-save-optim \ + --eval-interval 1000 \ + --save $main_dir/data/checkpoints \ + --fast-load \ + --load data/checkpoints/cogview-base \ + --finetune +" + +# --finetune + # --save $main_dir/data/checkpoints \ + # --restart-iter 199000 + + + + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/pretrain_single_node.sh b/scripts_old/pretrain_single_node.sh similarity index 100% rename from scripts/pretrain_single_node.sh rename to scripts_old/pretrain_single_node.sh diff --git a/scripts/super_resolution.sh b/scripts_old/super_resolution.sh similarity index 100% rename from scripts/super_resolution.sh rename to scripts_old/super_resolution.sh diff --git a/scripts/testnan.sh b/scripts_old/testnan.sh similarity index 100% rename from scripts/testnan.sh rename to scripts_old/testnan.sh diff --git a/scripts/text2image.sh b/scripts_old/text2image.sh similarity index 100% rename from scripts/text2image.sh rename to scripts_old/text2image.sh diff --git a/tokenization/cogview/unified_tokenizer.py b/tokenization/cogview/unified_tokenizer.py index bf66d9a..d72741d 100755 --- a/tokenization/cogview/unified_tokenizer.py +++ b/tokenization/cogview/unified_tokenizer.py @@ -35,7 +35,7 @@ class UnifiedTokenizer(object): ('[EOI2]', 5), ('[EOI3]', 6), ('[ROI1]', 7), # Reference - ('[ROI2]', 8), + ('[ROI2]', 8), # 58200 ('[ROI3]', 9), ('[SEP]', 10), ('[MASK]', 
11), diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py index 5519003..cff164c 100644 --- a/training/deepspeed_training.py +++ b/training/deepspeed_training.py @@ -36,11 +36,12 @@ from utils import print_rank_0 from utils import get_sample_writer import mpu -from data_utils import make_loaders, get_tokenizer +from data_utils import make_loaders +from tokenization import get_tokenizer -def main(args, model_cls, forward_step_function, create_dataset_function, init_function=None): +def training_main(args, model_cls, forward_step_function, create_dataset_function, init_function=None): """Main training program.""" hooks = { 'forward_step': forward_step_function, @@ -50,7 +51,6 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.enabled = False # Disable CuDNN. - set_random_seed(args.seed) # Random seeds for reproducability. timers = Timers() # Timer. # Experiment Name @@ -61,6 +61,8 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f # Pytorch distributed. initialize_distributed(args) + set_random_seed(args.seed) # Random seeds for reproducibility. + # init tokenizer tokenizer = get_tokenizer(args) # Data stuff. @@ -71,7 +73,7 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f # Config model IO if args.load is not None: - args.iteration = load_checkpoint(model, optimizer, args) + args.iteration = load_checkpoint(model, args) # if we don't load optim_states, filelock is no more needed. 
# with FileLock("/root/checkpoint_lock", timeout=-1): # args.iteration = load_checkpoint(model, optimizer, args) @@ -82,7 +84,7 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f torch.distributed.barrier() # initialize lr scheduler - lr_scheduler = get_learning_rate_scheduler(optimizer, args, args.iteration) + lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args) summary_writer = None if torch.distributed.get_rank() == 0: @@ -217,7 +219,7 @@ def get_optimizer_param_groups(model): return param_groups def get_learning_rate_scheduler(optimizer, iteration, args, - auto_warmup_steps=50, auto_warmup_rate=0.05): + auto_warmup_steps=100, auto_warmup_rate=0.05): """Build the learning rate scheduler.""" # Add linear learning rate scheduler. @@ -226,10 +228,9 @@ def get_learning_rate_scheduler(optimizer, iteration, args, else: num_iters = args.train_iters num_iters = max(1, num_iters) - if args.mode == 'pretrain': - init_step = max(iteration-auto_warmup_steps, 0) - elif args.mode == 'finetune': - init_step = 0 + init_step = max(iteration-auto_warmup_steps, 0) + if args.mode == 'pretrain' and iteration == 0: + auto_warmup_steps = 0 # If init_step <= current_steps <= init_step + auto_warmup_steps, # lr = auto_warmup_rate * args.lr. # This overrides other rules. 
@@ -335,7 +336,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, # Check nan or inf in forward, preventing it from interfering loss scaler, # and all reduce metrics by the way - loss_checker = lm_loss.detach().item() + loss_checker = lm_loss.detach() for name in metrics: metrics[name] = metrics[name].detach() torch.distributed.all_reduce(metrics[name].data) diff --git a/training/learning_rates.py b/training/learning_rates.py index fd325c7..ac984f3 100755 --- a/training/learning_rates.py +++ b/training/learning_rates.py @@ -47,9 +47,9 @@ class AnnealingLR(_LRScheduler): return float(self.start_lr) * self.num_iters / self.warmup_iter else: if self.decay_style == self.DECAY_STYLES[0]: - return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/real_end_iter) + return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter) elif self.decay_style == self.DECAY_STYLES[1]: - decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / real_end_iter) + decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter) return self.start_lr / self.decay_ratio * ( (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1) elif self.decay_style == self.DECAY_STYLES[2]: diff --git a/training/model_io.py b/training/model_io.py index fbc2254..f7527f9 100644 --- a/training/model_io.py +++ b/training/model_io.py @@ -108,7 +108,7 @@ def get_checkpoint_iteration(args): return iteration, release, True -def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True): +def load_checkpoint(model, args): """Load a model checkpoint.""" iteration, release, success = get_checkpoint_iteration(args) @@ -137,7 +137,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states= else: # new params assert all(name.find('mixins')>0 for name in missing_keys) module.reinit() # initialize mixins - model.optimizer.refresh_fp32_params() # restore fp32 weights from 
module + model.optimizer.refresh_fp32_params() # restore fp32 weights from module # Iterations. if args.mode == 'finetune': diff --git a/utils.py b/utils.py index 8ae815c..2430bc2 100755 --- a/utils.py +++ b/utils.py @@ -21,9 +21,6 @@ import time import numpy as np import torch -from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP -import mpu -import model from tensorboardX import SummaryWriter SUMMARY_WRITER_DIR_NAME = 'runs' @@ -132,311 +129,3 @@ def report_memory(name): torch.cuda.memory_reserved() / mega_bytes) print_rank_0(string) - -def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): - if release: - d = 'release' - else: - d = '{:d}'.format(iteration) - if zero: - dp_rank = mpu.get_data_parallel_rank() - d += '_zero_dp_rank_{}'.format(dp_rank) - return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) - - -def ensure_directory_exists(filename): - dirname = os.path.dirname(filename) - if not os.path.exists(dirname): - os.makedirs(dirname) - - -def get_checkpoint_tracker_filename(checkpoints_path): - return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') - - -def save_zero_checkpoint(args, iteration, optimizer): - zero_sd = {'iteration': iteration, - 'optimizer_state_dict': optimizer.state_dict()} - zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) - ensure_directory_exists(zero_checkpoint_name) - torch.save(zero_sd, zero_checkpoint_name) - print(' successfully saved {}'.format(zero_checkpoint_name)) - - -def save_checkpoint(iteration, model, optimizer, - lr_scheduler, args): - """Save a model checkpoint.""" - if args.deepspeed: - save_ds_checkpoint(iteration, model, lr_scheduler, args) - else: - # Only rank zer0 of the data parallel writes to the disk. 
- if isinstance(model, torchDDP): - model = model.module - - if mpu.get_data_parallel_rank() == 0: - checkpoint_name = get_checkpoint_name(args.save, iteration) - print('global rank {} is saving checkpoint at iteration {:7d} to {}'. - format(torch.distributed.get_rank(), iteration, checkpoint_name)) - - sd = {} - sd['iteration'] = iteration - sd['module'] = model.state_dict() - - # Optimizer stuff. - if not args.no_save_optim: - if optimizer is not None: - sd['optimizer'] = optimizer.state_dict() - if lr_scheduler is not None: - sd['lr_scheduler'] = lr_scheduler.state_dict() - - # rng states. - if not args.no_save_rng: - sd['random_rng_state'] = random.getstate() - sd['np_rng_state'] = np.random.get_state() - sd['torch_rng_state'] = torch.get_rng_state() - sd['cuda_rng_state'] = torch.cuda.get_rng_state() - sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() - - ensure_directory_exists(checkpoint_name) - torch.save(sd, checkpoint_name) - print(' successfully saved {}'.format(checkpoint_name)) - - # Wait so everyone is done (necessary) - torch.distributed.barrier() - # And update the latest iteration - if torch.distributed.get_rank() == 0: - tracker_filename = get_checkpoint_tracker_filename(args.save) - with open(tracker_filename, 'w') as f: - f.write(str(iteration)) - # Wait so everyone is done (not necessary) - torch.distributed.barrier() - - -def save_ds_checkpoint(iteration, model, lr_scheduler, args): - """Save a model checkpoint.""" - - sd = {} - sd['iteration'] = iteration - if lr_scheduler is not None: - sd['client_lr_scheduler'] = lr_scheduler.state_dict() - # rng states. 
- if not args.no_save_rng: - sd['random_rng_state'] = random.getstate() - sd['np_rng_state'] = np.random.get_state() - sd['torch_rng_state'] = torch.get_rng_state() - sd['cuda_rng_state'] = torch.cuda.get_rng_state() - sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() - if args.no_save_optim: - save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd) - else: - model.save_checkpoint(args.save, str(iteration), client_state=sd) - -def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True): - - os.makedirs(save_dir, exist_ok=True) - - if tag is None: - tag = f"global_step{model.global_steps}" - - # Ensure tag is a string - tag = str(tag) - - # Ensure checkpoint tag is consistent across ranks - model._checkpoint_tag_validation(tag) - - if model.save_non_zero_checkpoint: - model._create_checkpoint_file(save_dir, tag, False) - model._save_checkpoint(save_dir, tag, client_state=client_state) - - # Save latest checkpoint tag - if save_latest: - with open(os.path.join(save_dir, 'latest'), 'w') as fd: - fd.write(tag) - - return True - - -def get_checkpoint_iteration(args): - # Read the tracker file and set the iteration. - tracker_filename = get_checkpoint_tracker_filename(args.load) - if not os.path.isfile(tracker_filename): - print_rank_0('WARNING: could not find the metadata file {} '.format( - tracker_filename)) - print_rank_0(' will not load any checkpoints and will start from ' - 'random') - return 0, False, False - iteration = 0 - release = False - with open(tracker_filename, 'r') as f: - metastring = f.read().strip() - try: - iteration = int(metastring) - except ValueError: - release = metastring == 'release' - if not release: - print_rank_0('ERROR: Invalid metadata file {}. 
Exiting'.format( - tracker_filename)) - exit() - - assert iteration > 0 or release, 'error parsing metadata file {}'.format( - tracker_filename) - - return iteration, release, True - - -def extend_position_embedding(weight, length): - ori_length, hidden_size = weight.shape - assert length % ori_length == 0 - position_embeddings = weight.expand(length // ori_length, -1, -1).reshape(length, hidden_size) - return position_embeddings - -def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True): - """Load a model checkpoint.""" - - iteration, release, success = get_checkpoint_iteration(args) - - if not success: - return 0 - - if args.deepspeed: - - checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune) - # if args.finetune: - # model.module.module.init_plus_from_old() - if (args.finetune or args.no_load_optim) and model.zero_optimization(): - model.optimizer.refresh_fp32_params() - if "client_lr_scheduler" in sd and not args.finetune: - lr_scheduler.load_state_dict(sd["client_lr_scheduler"]) - print_rank_0("Load lr scheduler state") - if checkpoint_name is None: - if mpu.get_data_parallel_rank() == 0: - print("Unable to load checkpoint.") - return iteration - - else: - - # Checkpoint. - checkpoint_name = get_checkpoint_name(args.load, iteration, release) - - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - # Load the checkpoint. - sd = torch.load(checkpoint_name, map_location='cpu') - - if isinstance(model, torchDDP): - model = model.module - - # Model. - try: - model.load_state_dict(sd['module']) - except KeyError: - print_rank_0('A metadata file exists but unable to load model ' - 'from checkpoint {}, exiting'.format(checkpoint_name)) - exit() - - # Optimizer. 
- if not release and not args.finetune and not args.no_load_optim: - try: - if optimizer is not None and load_optimizer_states: - optimizer.load_state_dict(sd['optimizer']) - if lr_scheduler is not None: - lr_scheduler.load_state_dict(sd['lr_scheduler']) - except KeyError: - print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' - 'Specify --no-load-optim or --finetune to prevent ' - 'attempting to load the optimizer ' - 'state.'.format(checkpoint_name)) - exit() - - # Iterations. - if args.finetune or release: - iteration = 0 - else: - try: - iteration = sd['iteration'] - except KeyError: - try: # Backward compatible with older checkpoints - iteration = sd['total_iters'] - except KeyError: - print_rank_0('A metadata file exists but Unable to load iteration ' - ' from checkpoint {}, exiting'.format(checkpoint_name)) - exit() - - # rng states. - if not release and not args.finetune and not args.no_load_rng: - try: - random.setstate(sd['random_rng_state']) - np.random.set_state(sd['np_rng_state']) - torch.set_rng_state(sd['torch_rng_state']) - torch.cuda.set_rng_state(sd['cuda_rng_state']) - mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) - except KeyError: - print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' - 'Specify --no-load-rng or --finetune to prevent ' - 'attempting to load the random ' - 'state.'.format(checkpoint_name)) - exit() - - if mpu.get_data_parallel_rank() == 0: - print(' successfully loaded {}'.format(checkpoint_name)) - del sd - return iteration - - -def load_weights(src, dst, dst2src=False): - """ - Loads weights from src to dst via in place copy. - src is a huggingface gpt2model, while dst is one of our models. - dst2src=True loads parameters from our models into huggingface's. 
- ^dst2src is still untested - """ - conv_layer = 'Conv1D' in str(type(src)) - for n, p in src.named_parameters(): - if dst2src: - data = dst._parameters[n].data - load = p.data - else: - data = p.data - load = dst._parameters[n].data - if conv_layer and 'weight' in n: - data = data.t().contiguous() - load.copy_(data) - - -# dst._parameters[n].data.copy_(data) - -def load_mlp(our, oai, dst2src=False): - load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) - load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) - - -def load_attention(our, oai, dst2src=False): - load_weights(oai.c_attn, our.query_key_value, dst2src) - load_weights(oai.c_proj, our.dense, dst2src) - - -def load_transformer_layer(our, oai, dst2src=False): - load_weights(oai.ln_1, our.input_layernorm, dst2src) - load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) - load_mlp(our.mlp, oai.mlp, dst2src) - load_attention(our.attention, oai.attn, dst2src) - - -def move_weights(our, oai, dst2src=False): - """ - Loads weights from `oai` to `our` via in place copy. - `oai` is a huggingface gpt2model, while `our` is one of our models. - dst2src=True loads parameters from our models into huggingface's. - ^dst2src=True is still untested - """ - # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): - # our=our.module - transformer_model = oai.transformer - load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src) - load_weights(transformer_model.wte, our.word_embeddings, dst2src) - load_weights(transformer_model.wpe, our.position_embeddings, dst2src) - - for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): - load_transformer_layer(our_layer, oai_layer, dst2src) \ No newline at end of file -- GitLab