From 7ea243f037ea637c876c1c2308eae5b2fd5706c4 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 7 Oct 2021 18:37:53 +0000
Subject: [PATCH] test load model, perform normally

---
 .gitignore                                    |   2 +-
 arguments.py                                  | 235 ++---------
 data_utils/configure_data.py                  |   3 +-
 model/__init__.py                             |  18 -
 model/base_model.py                           |   5 +-
 model/cuda2d_model.py                         |  16 +-
 move_weights.py                               | 105 +++++
 mpu/__init__.py                               |   5 -
 mpu/layers.py                                 |   2 -
 mpu/random.py                                 | 386 ------------------
 mpu/transformer.py                            |  12 +-
 pretrain_cogview2.py                          | 170 ++++++++
 pretrain_gpt2.py                              |  16 +-
 scripts/ds_config_zero.json                   |   2 +-
 scripts/pretrain_cogview2.sh                  |  59 +++
 scripts/pretrain_multiple_nodes.sh            |  46 +--
 .../cuda_2d_text2image.sh                     |   0
 {scripts => scripts_old}/image2text.sh        |   0
 .../low_level_super_resolution.sh             |   0
 {scripts => scripts_old}/post_selection.sh    |   0
 scripts_old/pretrain_multiple_nodes.sh        |  74 ++++
 .../pretrain_single_node.sh                   |   0
 {scripts => scripts_old}/super_resolution.sh  |   0
 {scripts => scripts_old}/testnan.sh           |   0
 {scripts => scripts_old}/text2image.sh        |   0
 tokenization/cogview/unified_tokenizer.py     |   2 +-
 training/deepspeed_training.py                |  23 +-
 training/learning_rates.py                    |   4 +-
 training/model_io.py                          |   4 +-
 utils.py                                      | 311 --------------
 30 files changed, 516 insertions(+), 984 deletions(-)
 create mode 100644 move_weights.py
 delete mode 100755 mpu/random.py
 create mode 100755 pretrain_cogview2.py
 create mode 100755 scripts/pretrain_cogview2.sh
 rename {scripts => scripts_old}/cuda_2d_text2image.sh (100%)
 rename {scripts => scripts_old}/image2text.sh (100%)
 rename {scripts => scripts_old}/low_level_super_resolution.sh (100%)
 rename {scripts => scripts_old}/post_selection.sh (100%)
 create mode 100755 scripts_old/pretrain_multiple_nodes.sh
 rename {scripts => scripts_old}/pretrain_single_node.sh (100%)
 rename {scripts => scripts_old}/super_resolution.sh (100%)
 rename {scripts => scripts_old}/testnan.sh (100%)
 rename {scripts => scripts_old}/text2image.sh (100%)

diff --git a/.gitignore b/.gitignore
index c313862..45ed55c 100755
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,7 @@ _cache*
 .vscode/
 samples/
 hostfile
-pretrained/checkpoints
+pretrained/
 *.png
 *.jpg
 *.jpeg
diff --git a/arguments.py b/arguments.py
index 7ea0432..7344533 100755
--- a/arguments.py
+++ b/arguments.py
@@ -39,60 +39,20 @@ def add_model_config_args(parser):
                        help='layer norm epsilon')
     group.add_argument('--hidden-dropout', type=float, default=0.1,
                        help='dropout probability for hidden state transformer')
-    group.add_argument('--max-position-embeddings', type=int, default=512,
+    group.add_argument('--max-sequence-length', type=int, default=512,
                        help='maximum number of position embeddings to use')
-    group.add_argument('--vocab-size', type=int, default=30522,
+    group.add_argument('--vocab-size', type=int, default=0,
                        help='vocab size to use for non-character-level '
                             'tokenization. This value will only be used when '
                             'creating a tokenizer')
-    group.add_argument('--deep-init', action='store_true',
-                       help='initialize bert model similar to gpt2 model.'
-                            'scales initialization of projection layers by a '
-                            'factor of 1/sqrt(2N). Necessary to train bert '
-                            'models larger than BERT-Large.')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                             'This is added for computational efficieny reasons.')
-    group.add_argument('--cpu-optimizer', action='store_true',
-                       help='Run optimizer on CPU')
-    group.add_argument('--cpu_torch_adam', action='store_true',
-                       help='Use Torch Adam as optimizer on CPU.')
-    
-    group.add_argument('--max-position-embeddings-finetune', type=int, default=-1,
-                       help='maximum number of position embeddings to use in finetune')
     group.add_argument('--sandwich-ln', action='store_true',
                        help='add sandwich ln in cogview.')
     return parser
 
 
-def add_fp16_config_args(parser):
-    """Mixed precision arguments."""
-
-    group = parser.add_argument_group('fp16', 'fp16 configurations')
-
-    group.add_argument('--fp16', action='store_true',
-                       help='Run model in fp16 mode')
-    group.add_argument('--fp32-embedding', action='store_true',
-                       help='embedding in fp32')
-    group.add_argument('--fp32-layernorm', action='store_true',
-                       help='layer norm in fp32')
-    group.add_argument('--fp32-tokentypes', action='store_true',
-                       help='embedding token types in fp32')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='all-reduce in fp32')
-    group.add_argument('--hysteresis', type=int, default=2,
-                       help='hysteresis for dynamic loss scaling')
-    group.add_argument('--loss-scale', type=float, default=None,
-                       help='Static loss scaling, positive power of 2 '
-                            'values can improve fp16 convergence. If None, dynamic'
-                            'loss scaling is used.')
-    group.add_argument('--loss-scale-window', type=float, default=1000,
-                       help='Window over which to raise/lower dynamic scale')
-    group.add_argument('--min-scale', type=float, default=1,
-                       help='Minimum loss scale for dynamic loss scale')
-
-    return parser
-
 
 def add_training_args(parser):
     """Training arguments."""
@@ -110,10 +70,6 @@ def add_training_args(parser):
                             'with larger models and sequences')
     group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='chunk size (number of layers) for checkpointing')
-    group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
-                       help='uses activation checkpointing from deepspeed')
-    group.add_argument('--clip-grad', type=float, default=1.0,
-                       help='gradient clipping')
     group.add_argument('--train-iters', type=int, default=1000000,
                        help='total number of iterations to train over all training runs')
     group.add_argument('--log-interval', type=int, default=50,
@@ -123,16 +79,6 @@ def add_training_args(parser):
     group.add_argument('--summary-dir', type=str, default="", help="The directory to store the summary")
     group.add_argument('--seed', type=int, default=1234,
                        help='random seed')
-    group.add_argument('--img-tokenizer-path', type=str, default=None,
-                       help='The checkpoint file path of image tokenizer.')
-    group.add_argument('--img-tokenizer-num-tokens', type=int, default=None,
-                       help='The num tokens of image tokenizer. ONLY use for pretraining with img-tokenizer UNKNOW.')
-    # Batch prodecuer arguments
-    group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset posistion ids after end-of-document token.')
-    group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention maske after '
-                            'end-of-document token.')
 
     # Learning rate.
     group.add_argument('--lr-decay-iters', type=int, default=None,
@@ -147,31 +93,32 @@ def add_training_args(parser):
     group.add_argument('--warmup', type=float, default=0.01,
                        help='percentage of data to warmup on (.01 = 1% of all '
                             'training iters). Default 0.01')
-    group.add_argument('--restart-iter', type=int, default=0,
-                       help='restart with warmup from this iteration.')
     # model checkpointing
     group.add_argument('--save', type=str, default=None,
                        help='Output directory to save checkpoints to.')
+    group.add_argument('--load', type=str, default=None,
+                       help='Path to a directory containing a model checkpoint.')
     group.add_argument('--save-interval', type=int, default=5000,
                        help='number of iterations between saves')
-    group.add_argument('--no-save-optim', action='store_true',
-                       help='Do not save current optimizer.')
+    # group.add_argument('--no-save-optim', action='store_true',
+    #                    help='Do not save current optimizer.')
+    # group.add_argument('--no-load-optim', action='store_true',
+    #                    help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-save-rng', action='store_true',
                        help='Do not save current rng state.')
-    group.add_argument('--load', type=str, default=None,
-                       help='Path to a directory containing a model checkpoint.')
-    group.add_argument('--no-load-optim', action='store_true',
-                       help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',
                        help='Do not load rng state when loading checkpoint.')
-    group.add_argument('--finetune', action='store_true',
-                       help='Load model for finetuning. Do not load optimizer '
-                            'or rng state from checkpoint and set iteration to 0. '
-                            'Assumed when loading a release checkpoint.')
+    group.add_argument('--mode', type=str,
+                       default='pretrain',
+                       choices=['pretrain',
+                                'finetune',
+                                'inference'
+                                ],
+                       help='what type of task to use, will influence auto-warmup, exp name, iteration')
     group.add_argument('--resume-dataloader', action='store_true',
                        help='Resume the dataloader when resuming training. '
                             'Does not apply to tfrecords dataloader, try resuming'
-                            'with a different seed in this case.')
+                            'with a different seed in this case.') 
     # distributed training args
     group.add_argument('--distributed-backend', default='nccl',
                        help='which backend to use for distributed '
@@ -180,11 +127,9 @@ def add_training_args(parser):
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
     
-    # loss scale
-    group.add_argument('--txt-loss-scale', type=float, default=1)
-    group.add_argument('--fast-load', action='store_true',
-                       help='load checkpoints without locks.')
-
+    group.add_argument('--fp16', action='store_true',
+                       help='Run model in fp16 mode')
+    
     return parser
 
 
@@ -243,9 +188,9 @@ def add_data_args(parser):
 
     group.add_argument('--model-parallel-size', type=int, default=1,
                        help='size of the model parallel.')
-    group.add_argument('--shuffle', action='store_true',
-                       help='Shuffle data. Shuffling is deterministic '
-                            'based on seed and current epoch.')
+    # group.add_argument('--shuffle', action='store_true',
+    #                    help='Shuffle data. Shuffling is deterministic '
+    #                         'based on seed and current epoch.')
     group.add_argument('--train-data', nargs='+', default=None,
                        help='Whitespace separated filenames or corpora names '
                             'for training.')
@@ -261,20 +206,6 @@ def add_data_args(parser):
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
 
-    group.add_argument('--dataset-type', type=str,
-                       default='TokenizedDataset',
-                       choices=['TokenizedDataset',
-                                'TextCodeDataset',
-                                'CompactBinaryDataset',
-                                'BinaryDataset'
-                                ],
-                       help='what type of dataset to use')
-
-    group.add_argument('--max-memory-length', type=int, default=2048,
-                       help="max memory buffer for attention")
-    group.add_argument('--new-dataset-path', type=str, default=None,
-                       help='The folder we will dynamically check for lmdbs during training.')
-    
     return parser
 
 def add_generation_api_args(parser):
@@ -290,102 +221,59 @@ def add_generation_api_args(parser):
     group.add_argument('--device', default=None)
 
     return parser
-
-def add_sparse_args(parser):
+    
+def add_tokenization_args(parser):
     """sparse attention arguments."""
 
-    group = parser.add_argument_group('Sparse Attention', 'sparse configurations')
-    group.add_argument('--sparse-type', type=str, default='standard',
-                       choices=['standard', 'torch_1d', 'cuda_2d'],
-                       help='whether use sparse attention.') # TODO: Temporally not using is-sparse==2 (not optimized), use 0 for inference.
-    # for torch_1d
-    group.add_argument("--query-window", type=int, default=128)
-    group.add_argument("--key-window-times", type=int, default=6)
-    group.add_argument("--num-pivot", type=int, default=768)
-    # for cuda_2d
-    group.add_argument("--kernel-size", type=int, default=9)
-    group.add_argument("--kernel-size2", type=int, default=7)
-    group.add_argument("--layout", type=str, default='64,1088,5184')
+    group = parser.add_argument_group('Tokenization', 'tokenization configurations')
+    group.add_argument('--tokenizer-type', type=str, default='fake', help='type name of tokenizer')
+
+    group.add_argument('--img-tokenizer-path', type=str, default=None,
+                       help='The checkpoint file path of image tokenizer.')
     return parser
 
-def make_sparse_config(args):
-    args.layout = [int(x) for x in args.layout.split(',')]
-    sparse_config = argparse.Namespace(sparse_type=args.sparse_type)
-    sparse_config.layout = args.layout
-    if args.sparse_type == 'standard':
-        pass
-    if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation':
-        sparse_config.kernel_size = args.kernel_size
-        sparse_config.kernel_size2 = args.kernel_size2
-    elif args.sparse_type == 'torch_1d':
-        raise NotImplementedError
-    args.sparse_config = sparse_config
-
-def get_args():
+
+    
+def get_args(args_list=None):
     """Parse all the args."""
 
-    parser = argparse.ArgumentParser(description='PyTorch CogView Model')
+    parser = argparse.ArgumentParser(description='Swiss Army Transformer')
     parser = add_model_config_args(parser)
-    parser = add_fp16_config_args(parser)
     parser = add_training_args(parser)
     parser = add_evaluation_args(parser)
-    parser = add_text_generate_args(parser)
     parser = add_data_args(parser)
+    parser = add_tokenization_args(parser)
+    parser = add_text_generate_args(parser)
     parser = add_generation_api_args(parser)
-    parser = add_sparse_args(parser)
 
     # Include DeepSpeed configuration arguments
     parser = deepspeed.add_config_arguments(parser)
 
-    args = parser.parse_args()
-    make_sparse_config(args)
+    args = parser.parse_args(args_list)
 
     if not args.train_data:
         print('WARNING: No training data specified')
-    elif args.sparse_type == 'torch_1d' and (args.max_position_embeddings - 1) % args.query_window != 0:
-        raise ValueError('During sparse training, the sequence length must be exactly divided by window_size.') 
 
     args.cuda = torch.cuda.is_available()
 
     args.rank = int(os.getenv('RANK', '0'))
     args.world_size = int(os.getenv("WORLD_SIZE", '1'))
-    if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
-        mpi_define_env(args)
-    elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
-        # We are using (OpenMPI) mpirun for launching distributed data parallel processes
-        local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
-        local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
-
-        # Possibly running with Slurm
-        num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
-        nodeid = int(os.getenv('SLURM_NODEID', '0'))
-
-        args.local_rank = local_rank
-        args.rank = nodeid * local_size + local_rank
-        args.world_size = num_nodes * local_size
+    
 
     args.model_parallel_size = min(args.model_parallel_size, args.world_size)
     if args.rank == 0:
         print('using world size: {} and model-parallel size: {} '.format(
             args.world_size, args.model_parallel_size))
 
-    args.dynamic_loss_scale = False
-    if args.loss_scale is None:
-        args.dynamic_loss_scale = True
-        if args.rank == 0:
-            print(' > using dynamic loss scaling')
-
-    # The args fp32_* or fp16_* meant to be active when the
-    # args fp16 is set. So the default behaviour should all
-    # be false.
-    if not args.fp16:
-        args.fp32_embedding = False
-        args.fp32_tokentypes = False
-        args.fp32_layernorm = False
-
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
         with open(args.deepspeed_config) as file:
             deepspeed_config = json.load(file)
+        if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+            args.fp16 = True
+        else:
+            args.fp16 = False
+        if args.checkpoint_activations:
+            args.deepspeed_activation_checkpointing = True
         if "train_micro_batch_size_per_gpu" in deepspeed_config:
             args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
         if "gradient_accumulation_steps" in deepspeed_config:
@@ -399,40 +287,3 @@ def get_args():
     return args
 
 
-def mpi_define_env(args):
-    ''' For training CogView via MPI to setup the connection.
-        Omit this function if use the basic deepspeed pdsh runner. 
-    '''
-    from mpi4py import MPI
-    import subprocess
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    world_size = comm.Get_size()
-
-    master_addr = None
-    if rank == 0:
-        hostname_cmd = ["hostname -I"]
-        result = subprocess.check_output(hostname_cmd, shell=True)
-        master_addr = result.decode('utf-8').split()[0]
-    master_addr = comm.bcast(master_addr, root=0)
-
-    # Determine local rank by assuming hostnames are unique
-    proc_name = MPI.Get_processor_name()
-    all_procs = comm.allgather(proc_name)
-    local_rank = sum([i == proc_name for i in all_procs[:rank]])
-
-    os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    args.local_rank = local_rank
-    args.world_size = world_size
-    args.rank = rank
-    os.environ['MASTER_ADDR'] = master_addr
-    os.environ['MASTER_PORT'] = "29500" # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
-
-    print(
-        "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
-        .format(os.environ['RANK'],
-                args.local_rank,
-                os.environ['WORLD_SIZE'],
-                os.environ['MASTER_ADDR'],
-                os.environ['MASTER_PORT']))
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index 6556673..95c3bd1 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -52,7 +52,7 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     return data_loader
 
 
-def make_dataset_full(dataset_type, path, split, args, create_dataset_function, **kwargs):
+def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     """function to create datasets+tokenizers for common options"""
     print('make dataset ...', path)
     if split is None:
@@ -93,7 +93,6 @@ def make_loaders(args, create_dataset_function):
 
     data_set_args = {
         'path': args.train_data,
-        'dataset_type': args.dataset_type,
         'split': split,
     }
 
diff --git a/model/__init__.py b/model/__init__.py
index beb1271..e69de29 100755
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -1,18 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization
-from .gpt2_modeling import GPT2Model
-
diff --git a/model/base_model.py b/model/base_model.py
index 7b7769c..2b74fb0 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -12,7 +12,6 @@ import sys
 import math
 import random
 import torch
-from functools import partial
 
 from mpu import BaseTransformer
 
@@ -28,7 +27,7 @@ class BaseModel(torch.nn.Module):
                 vocab_size=args.vocab_size,
                 hidden_size=args.hidden_size,
                 num_attention_heads=args.num_attention_heads,
-                max_sequence_length=args.max_position_embeddings,
+                max_sequence_length=args.max_sequence_length,
                 embedding_dropout_prob=args.hidden_dropout,
                 attention_dropout_prob=args.attention_dropout,
                 output_dropout_prob=args.hidden_dropout,
@@ -58,7 +57,7 @@ class BaseModel(torch.nn.Module):
         hooks = {}
         for name in names:
             if hasattr(self, name):
-                hooks[name] = partial(getattr(self, name), self)
+                hooks[name] = getattr(self, name)
         return hooks
 
     def disable_untrainable_params(self):
diff --git a/model/cuda2d_model.py b/model/cuda2d_model.py
index a9e1175..661e49f 100644
--- a/model/cuda2d_model.py
+++ b/model/cuda2d_model.py
@@ -20,8 +20,8 @@ from .mixins import PositionEmbeddingMixin, AttentionMixin
 
 from mpu.transformer import split_tensor_along_last_dim
 from mpu.local_attention_function import f_similar, f_weighting
-from mpu.random import get_cuda_rng_tracker
 from mpu.utils import sqrt
+from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
 
 
 class Cuda2dModel(BaseModel):
@@ -88,7 +88,19 @@ class Cuda2dModel(BaseModel):
         output_1 = dense_plus(context_layer1)
         output = torch.cat((output_0, output_1), dim=1)
         
-        return output
+        return output, None
+    
+    def disable_untrainable_params(self):
+        self.transformer.requires_grad_(False)
+    
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
+        group.add_argument("--kernel-size", type=int, default=9)
+        group.add_argument("--kernel-size2", type=int, default=7)
+        group.add_argument("--layout", type=str, default='64,1088,5184')
+        group.add_argument("--new-sequence-length", type=int, default=5185)
+        return parser
 
 def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, **kwargs):
     '''
diff --git a/move_weights.py b/move_weights.py
new file mode 100644
index 0000000..7d41f2b
--- /dev/null
+++ b/move_weights.py
@@ -0,0 +1,105 @@
+# %%
+import torch
+old = torch.load('pretrained/cogview/cogview-base/142000/mp_rank_00_model_states.pt', map_location='cpu')
+
+old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
+del old['module']['word_embeddings.weight']
+
+from model.base_model import BaseModel
+import argparse
+import os
+args = argparse.Namespace(
+    num_layers=48,
+    vocab_size=58240,
+    hidden_size=2560,
+    num_attention_heads=40,
+    max_sequence_length=1089,
+    hidden_dropout=0.1,
+    attention_dropout=0.1,
+    checkpoint_activations=True,
+    checkpoint_num_layers=1,
+    sandwich_ln=True,
+    model_parallel_size=1,
+    world_size=1,
+    rank=0
+    )
+init_method = 'tcp://'
+master_ip = os.getenv('MASTER_ADDR', 'localhost')
+master_port = os.getenv('MASTER_PORT', '6000')
+init_method += master_ip + ':' + master_port
+torch.distributed.init_process_group(
+        backend='nccl',
+        world_size=args.world_size, rank=args.rank,init_method=init_method)
+import mpu
+    # Set the model-parallel / data-parallel communicators.
+mpu.initialize_model_parallel(args.model_parallel_size)
+print('bg')
+model = BaseModel(args)
+missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
+torch.save(old, 'pretrained/cogview/cogview-base/142000/mp_rank_00_model_states.pt')
+# %%
+import torch
+old = torch.load('pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt', map_location='cpu')
+old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
+del old['module']['word_embeddings.weight']
+#%%
+from model.cuda2d_model import Cuda2dModel
+import argparse
+import os
+args = argparse.Namespace(
+    num_layers=48,
+    vocab_size=58240,
+    hidden_size=2560,
+    num_attention_heads=40,
+    max_sequence_length=1089,
+    hidden_dropout=0.1,
+    attention_dropout=0.1,
+    checkpoint_activations=True,
+    checkpoint_num_layers=1,
+    sandwich_ln=True,
+    model_parallel_size=1,
+    world_size=1,
+    rank=0,
+    new_sequence_length=1089+4096,
+    layout='0,64,1088,5184',
+    kernel_size=9,
+    kernel_size2=7
+    )
+# %%
+init_method = 'tcp://'
+master_ip = os.getenv('MASTER_ADDR', 'localhost')
+master_port = os.getenv('MASTER_PORT', '6000')
+init_method += master_ip + ':' + master_port
+torch.distributed.init_process_group(
+        backend='nccl',
+        world_size=args.world_size, rank=args.rank,init_method=init_method)
+import mpu
+    # Set the model-parallel / data-parallel communicators.
+mpu.initialize_model_parallel(args.model_parallel_size)
+print('bg')
+#%%
+model = Cuda2dModel(args)
+
+#%%
+old['module']['mixins.0.position_embeddings.weight'] = old['module']['transformer.position_embeddings_plus.weight']
+del old['module']['transformer.position_embeddings_plus.weight']
+
+for i in range(48):
+    old['module'][f'mixins.1.query_key_value.{i}.weight'] = \
+        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
+    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
+    old['module'][f'mixins.1.query_key_value.{i}.bias'] = \
+        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
+    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
+    old['module'][f'mixins.1.dense.{i}.weight'] = \
+        old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
+    del old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
+    old['module'][f'mixins.1.dense.{i}.bias'] = \
+        old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
+    del old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
+# %%
+missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
+
+# %%
+torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt')
+# %%
diff --git a/mpu/__init__.py b/mpu/__init__.py
index 358f997..c2ebd2f 100755
--- a/mpu/__init__.py
+++ b/mpu/__init__.py
@@ -40,10 +40,5 @@ from .mappings import gather_from_model_parallel_region
 from .mappings import reduce_from_model_parallel_region
 from .mappings import scatter_to_model_parallel_region
 
-from .random import checkpoint
-from .random import partition_activations_in_checkpoint
-from .random import get_cuda_rng_tracker
-from .random import model_parallel_cuda_manual_seed
-
 from .transformer import BaseTransformer
 from .transformer import LayerNorm
diff --git a/mpu/layers.py b/mpu/layers.py
index 2739bd9..72ec990 100755
--- a/mpu/layers.py
+++ b/mpu/layers.py
@@ -33,9 +33,7 @@ from .mappings import copy_to_model_parallel_region
 from .mappings import gather_from_model_parallel_region
 from .mappings import reduce_from_model_parallel_region
 from .mappings import scatter_to_model_parallel_region
-from .random import get_cuda_rng_tracker
 from .utils import divide
-from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 
 
diff --git a/mpu/random.py b/mpu/random.py
deleted file mode 100755
index 69f4cec..0000000
--- a/mpu/random.py
+++ /dev/null
@@ -1,386 +0,0 @@
-# coding=utf-8
-#Modified by Samyam Rajbhandari
-#Used to partition the activations stored for backward propagation
-#Therefore reduces the memory consumption
-
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Parts of the code here are adapted from PyTorch
-# repo: https://github.com/pytorch/pytorch
-import contextlib
-import torch.distributed as dist
-import torch
-from torch import _C
-from torch.cuda import _lazy_call, device as device_ctx_manager
-#from torch.utils.checkpoint import detach_variable
-
-
-import torch.distributed as dist
-PARTITION_ACTIVATIONS = False
-PA_CORRECTNESS_TEST= False
-
-def see_memory_usage(message, force=False):
-    if not force:
-        return
-    dist.barrier()
-    if dist.get_rank() == 0:
-        print(message)
-        print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
-        print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
-        print(" ")
-        #input("Press Any Key To Continue ..")
-
-
-from .initialize import get_data_parallel_rank
-from .initialize import get_model_parallel_rank
-from .initialize import get_model_parallel_world_size
-from .initialize import get_model_parallel_group
-
-mp_rank = None #get_model_parallel_rank()
-mp_size = None #get_model_parallel_world_size()
-mp_group = None #get_model_parallel_group()
-
-# Default name for the model parallel rng tracker.
-_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
-transport_stream = None 
-cuda_device=None
-def detach_variable(inputs, device=None):
-    if isinstance(inputs, tuple):
-        out = []
-        for inp in inputs:
-            if not isinstance(inp, torch.Tensor):
-                out.append(inp)
-                continue
-
-            requires_grad = inp.requires_grad
-
-            if device is not None:
-                x = inp.to(device=device)
-            else:
-                x = inp
- 
-            x = x.detach()
-            x.requires_grad = requires_grad
-            out.append(x)
-        return tuple(out)
-    else:
-        raise RuntimeError(
-            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
-
-def _set_cuda_rng_state(new_state, device=-1):
-    """Sets the random number generator state of the current GPU.
-
-    Argumentss:
-        new_state (torch.ByteTensor): The desired state
-    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
-    with a single change: the input state is not cloned. Cloning caused
-    major performance issues for +4 GPU cases.
-    """
-    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
-        # older PyTorch
-        def cb():
-            with device_ctx_manager(device):
-                _C._cuda_setRNGState(new_state)
-    else:
-        # newer PyTorch
-        if device == -1:
-            device = torch.device('cuda')
-        elif isinstance(device, str):
-            device = torch.device(device)
-        elif isinstance(device, int):
-            device = torch.device('cuda', device)
-
-        def cb():
-            idx = device.index
-            if idx is None:
-                idx = torch.cuda.current_device()
-            default_generator = torch.cuda.default_generators[idx]
-            default_generator.set_state(new_state)
-
-    _lazy_call(cb)
-
-
-
-class CudaRNGStatesTracker:
-    """Tracker for the cuda RNG states.
-
-    Using the `add` method, a cuda rng state is initialized based on
-    the input `seed` and is assigned to `name`. Later, by forking the
-    rng state, we can perform operations and return to our starting
-    cuda state.
-    """
-    def __init__(self):
-        # Map from a string name to the cuda rng state.
-        self.states_ = {}
-        # Seeds are just for book keeping and ensure no seed is set twice.
-        self.seeds_ = set()
-
-    def reset(self):
-        """Set to the initial state (no tracker)."""
-        self.states_ = {}
-        self.seeds_ = set()
-
-    def get_states(self):
-        """Get rng states. Copy the dictionary so we have direct
-        pointers to the states, not just a pointer to the dictionary."""
-        states = {}
-        for name in self.states_:
-            states[name] = self.states_[name]
-        return states
-
-    def set_states(self, states):
-        """Set the rng states. For efficiency purposes, we do not check
-        the size of seed for compatibility."""
-        self.states_ = states
-
-    def add(self, name, seed):
-        """Track the rng state."""
-        # Check seed is not already used.
-        if seed in self.seeds_:
-            raise Exception('seed {} already exists'.format(seed))
-        self.seeds_.add(seed)
-        # Check that state is not already defined.
-        if name in self.states_:
-            raise Exception('cuda rng state {} already exists'.format(name))
-        # Get the current rng state.
-        orig_rng_state = torch.cuda.get_rng_state()
-        # Set the new state and store it.
-        torch.cuda.manual_seed(seed)
-        self.states_[name] = torch.cuda.get_rng_state()
-        # Reset rng state to what it was.
-        _set_cuda_rng_state(orig_rng_state)
-
-    @contextlib.contextmanager
-    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
-        """Fork the cuda rng state, perform operations, and exit with
-        the original state."""
-        # Check if we have added the state
-        if name not in self.states_:
-            raise Exception('cuda rng state {} is not added'.format(name))
-        # Store current rng state.
-        orig_cuda_rng_state = torch.cuda.get_rng_state()
-        # Set rng state to the desired one
-        _set_cuda_rng_state(self.states_[name])
-        # Do the stuff we wanted to do.
-        try:
-            yield
-        finally:
-            # Update the current rng state for later use.
-            self.states_[name] = torch.cuda.get_rng_state()
-            # And set the state to the original state we started with.
-            _set_cuda_rng_state(orig_cuda_rng_state)
-
-
-# RNG tracker object.
-_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
-
-
-def get_cuda_rng_tracker():
-    """Get cuda rng tracker."""
-    return _CUDA_RNG_STATE_TRACKER
-
-
-def model_parallel_cuda_manual_seed(seed):
-    """Initialize model parallel cuda seed.
-
-    This function should be called after the model parallel is
-    initialized. Also, no torch.cuda.manual_seed should be called
-    after this function. Basically, this is replacement for that
-    function.
-    Two set of RNG states are tracked:
-        default state: This is for data parallelism and is the same among a
-                       set of model parallel GPUs but different across
-                       different model paralle groups. This is used for
-                       example for dropout in the non-model-parallel regions.
-        model-parallel state: This state is different among a set of model
-                              parallel GPUs, but the same across data parallel
-                              groups. This is used for example for dropout in
-                              model parallel regions.
-    """
-    # 2718 is just for fun and any POSITIVE value will work.
-    offset = seed + 2718
-    model_parallel_seed = offset + get_model_parallel_rank()
-    # Data parallel gets the original sedd.
-    data_parallel_seed = seed
-
-    if torch.distributed.get_rank() == 0:
-        print('> initializing model parallel cuda seeds on global rank {}, '
-              'model parallel rank {}, and data parallel rank {} with '
-              'model parallel seed: {} and data parallel seed: {}'.format(
-                  torch.distributed.get_rank(), get_model_parallel_rank(),
-                  get_data_parallel_rank(), model_parallel_seed,
-                  data_parallel_seed), flush=True)
-    _CUDA_RNG_STATE_TRACKER.reset()
-    # Set the default state.
-    torch.cuda.manual_seed(data_parallel_seed)
-    # and model parallel state.
-    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
-                                model_parallel_seed)
-
-
-def get_partition_start(item):
-    global mp_rank, mp_size, mp_group
-    partition_size = get_partition_size(item)
-    start = partition_size * mp_rank
-    return int(start)
-
-def get_partition_size(item):
-    global mp_rank, mp_size, mp_group
-    size = item.numel()
-    partition_size = size/mp_size
-    return int(partition_size)
-    
-def get_full_inputs(tensors):
-    inputs=[]
-    for i in range(int(len(tensors)/2)-1):
-        item = tensors[2 * i]
-        size = tensors[2* i + 1]
-        partition_size = item.numel()
-        tensor_size = partition_size * mp_size
-        flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
-        partitions=[]
-        for i in range(mp_size):
-            part_i = flat_tensor.narrow(0, partition_size * i , partition_size)
-            if i == mp_rank:
-                part_i.copy_(item)
-            partitions.append(part_i)
-        dist.all_gather(partitions,partitions[mp_rank], group=mp_group)
-        input_tensor = flat_tensor.view(list(size.numpy()))
-        item.data=input_tensor.data
-
-        inputs.append(item)
-    inputs.append(tensors[-2])
-        
-    return tuple(inputs)
-
-        
-
-class CheckpointFunction(torch.autograd.Function):
-    """This function is adapted from torch.utils.checkpoint with
-       two main changes:
-           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
-           2) the states in the model parallel tracker are also properly
-              tracked/set/reset.
-    """
-    @staticmethod
-    def forward(ctx, run_function, *args):
-        ctx.run_function = run_function
-        global mp_rank, mp_size, mp_group
-        if mp_rank is None:
-            mp_rank = get_model_parallel_rank()
-            mp_size = get_model_parallel_world_size()
-            mp_group = get_model_parallel_group()
-
-
-        global cuda_device, transport_stream, PARTITION_ACTIVATIONS
-        if cuda_device is None:
-            if dist.get_rank()  == 0:
-                print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}")
-            
-            cuda_device = torch.cuda.current_device()
-            #The transport stream is used to overlap the allgather communication for the activations
-            #with the computation in the backward pass
-            transport_stream = torch.cuda.Stream(device=cuda_device)
-
-        if PARTITION_ACTIVATIONS:
-            inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
-            inputs.append(args[-1])
-
-        #just in case something funky is happening such as reuse of inputs
-        inputs_cuda = [item.to(cuda_device) for item in args]
-        
-        # Copy the rng states.
-        ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
-        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
-        #ctx.save_for_backward(*args)
-        with torch.no_grad():
-            outputs = run_function(*inputs_cuda)
-
-        del inputs_cuda
-        
-        if PARTITION_ACTIVATIONS:
-            new_args = []
-            for arg, inp in zip(args,inputs):                
-                size= torch.tensor(arg.size())
-                arg.data = inp.data
-                new_args.append(arg)
-                new_args.append(size)
-            ctx.save_for_backward(*new_args)
-        else:
-            ctx.save_for_backward(*args)
-        
-        return outputs
-
-    @staticmethod
-    def backward(ctx, *args):
-        if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError("Checkpointing is not compatible with .grad(), "
-                               "please use .backward() if possible")
-        
-        global cuda_device, transport_stream, PARTITION_ACTIVATIONS
-        
-        if PARTITION_ACTIVATIONS:
-            with torch.cuda.stream(transport_stream):
-                inputs = get_full_inputs(ctx.saved_tensors)
-                detached_inputs = detach_variable(inputs)
-        else:
-            inputs = ctx.saved_tensors
-            detached_inputs = detach_variable(inputs)
-
-        # Store the current states.
-        bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = torch.cuda.get_rng_state()
-        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
-        # Set the states to what it used to be before the forward pass.
-        torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
-        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
-        
-        if PARTITION_ACTIVATIONS:
-            current_stream=torch.cuda.current_stream()
-            current_stream.wait_stream(transport_stream)
-
-        with torch.enable_grad():
-            outputs = ctx.run_function(*detached_inputs)
-
-        # Set the states back to what it was at the start of this function.
-        torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state)
-        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
-
-        if isinstance(outputs, torch.Tensor):
-            outputs = (outputs,)
-        torch.autograd.backward(outputs, args)
-        return (None,) + tuple(inp.grad for inp in detached_inputs)
-
-
-def checkpoint(function, *args):
-    """Checkpoint a model or part of the model.
-    This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
-
-def partition_activations_in_checkpoint(partition_activation):
-    global PARTITION_ACTIVATIONS
-    PARTITION_ACTIVATIONS=partition_activation
-    if dist.get_rank()  == 0:
-        print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
-
-
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 9eab060..099dcb5 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -27,9 +27,7 @@ from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedd
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
 import deepspeed
-
-from .random import checkpoint
-from .random import get_cuda_rng_tracker
+from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
 from .utils import split_tensor_along_last_dim
@@ -250,11 +248,11 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output = self.mlp(layernorm_output,  *other_tensors)
 
         # Fourth LayerNorm
         if self.sandwich_ln:
-            mlp_output = self.fourth_layernorm(mlp_output, *other_tensors)
+            mlp_output = self.fourth_layernorm(mlp_output)
 
         # Second residual connection.
         output = layernorm_input + mlp_output
@@ -280,10 +278,6 @@ class BaseTransformer(torch.nn.Module):
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
-        if deepspeed.checkpointing.is_configured():
-            global get_cuda_rng_tracker, checkpoint
-            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
-            checkpoint = deepspeed.checkpointing.checkpoint
         
         # recording parameters
         self.parallel_output = parallel_output
diff --git a/pretrain_cogview2.py b/pretrain_cogview2.py
new file mode 100755
index 0000000..032fade
--- /dev/null
+++ b/pretrain_cogview2.py
@@ -0,0 +1,170 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   pretrain_gpt2.py
+@Time    :   2021/10/06 00:58:32
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+import numpy as np
+
+import mpu
+from arguments import get_args
+from model.cuda2d_model import Cuda2dModel
+from training.deepspeed_training import training_main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
+
+def get_masks_and_position_ids(data,
+                            loss_mask=None,
+                            attention_mask=None, args=None):
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular).
+    if attention_mask is None:
+        assert loss_mask is not None
+        # loss_mask has n_pad(+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse.
+        attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril()
+        for i in range(batch_size):
+            attention_mask[i].fill_diagonal_(1)
+        attention_mask = attention_mask.unsqueeze(1)
+        
+    # Loss mask.
+    if loss_mask is None:
+        loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
+
+    # Position ids.
+    assert loss_mask is not None
+    layout = args.layout
+    assert seq_length == layout[-1]
+    n_pads = seq_length - loss_mask.sum(dim=-1).long()
+    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
+                                device=data.device)
+    for i in range(batch_size):
+        torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]], 
+            dtype=torch.long, device=data.device)
+        torch.arange(layout[2] - layout[1], 
+            out=position_ids[i, layout[1]:],
+            dtype=torch.long, device=data.device)
+
+    return attention_mask, loss_mask, position_ids
+
+
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['text', 'loss_mask']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens_ = data_b['text'].long()
+    loss_mask = data_b['loss_mask'].float()
+
+    labels = tokens_[:, 1:].contiguous()
+    loss_mask = loss_mask[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
+    
+    attention_mask = None        
+
+    # Get the masks and postition ids.
+    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+        tokens,
+        loss_mask=loss_mask,
+        attention_mask=attention_mask,
+        args=args
+        )
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+
+    return tokens, labels, loss_mask, attention_mask, position_ids
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+
+    # split img & txt positions, [PAD] not included # TODO check enough
+    tokenizer = get_tokenizer()
+    img_txt_sep = tokenizer.img_tokenizer.num_tokens
+    img_indices_bool = (tokens < img_txt_sep) & (loss_mask > 0)
+    txt_indices_bool = (~img_indices_bool) & (loss_mask > 0)
+    # Forward model.
+    logits, *mems = model(tokens, position_ids, attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
+    # scaling loss mask
+    loss_mask[txt_indices_bool] *= args.txt_loss_scale
+    loss_mask = loss_mask.view(-1)  
+
+    losses = losses.view(-1) * loss_mask
+    loss = torch.sum(losses) / loss_mask.sum()
+    # =====================   Log partial losses   ======================== #
+    img_indices_bool2 = img_indices_bool.clone()
+    img_indices_bool2[:, :args.layout[1]] = False
+    img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1)
+    
+    img_indices_bool = img_indices_bool.view(-1)
+    txt_indices_bool = txt_indices_bool.view(-1)
+    img_loss = losses[img_indices_bool].detach().sum() / max(img_indices_bool.sum(), 1)
+    txt_loss = losses[txt_indices_bool].detach().sum() / max(txt_indices_bool.sum(), 1) / args.txt_loss_scale
+    # ===================== END OF BLOCK ======================= #
+    return loss, {'img_loss': img_loss, 'txt_loss': txt_loss, 'img_loss2': img_loss2}
+    
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+        
+        text = row[:layout[0]]
+        text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
+        n_pad = layout[0]-3-len(text)
+        parts = [
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+            TextCodeTemplate(text, codes[1], tokenizer),
+            *codes[2:]
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret, 
+            'loss_mask':  np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS]
+            }
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
+if __name__ == '__main__':    
+    py_parser = argparse.ArgumentParser(add_help=False)
+    
+    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
+    
+    Cuda2dModel.add_model_specific_args(py_parser)
+    
+    known, args_list = py_parser.parse_known_args()
+    
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    
+    args.layout = [int(x) for x in args.layout.split(',')]
+    
+    training_main(args, model_cls=Cuda2dModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index a8a0ea2..d5df575 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -12,12 +12,13 @@ import sys
 import math
 import random
 import torch
+import argparse
 import numpy as np
 
 import mpu
 from arguments import get_args
 from model.base_model import BaseModel
-from training.deepspeed_training import main
+from training.deepspeed_training import training_main
 from data_utils import BinaryDataset
 from tokenization import get_tokenizer
 from tokenization.cogview import TextCodeTemplate
@@ -32,6 +33,7 @@ def get_masks_and_position_ids(data,
     if attention_mask is None:
         attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
         attention_mask.tril_()
+        attention_mask.unsqueeze_(1)
         
     # Loss mask.
     if loss_mask is None:
@@ -91,7 +93,6 @@ def forward_step(data_iterator, model, args, timers):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator, args, timers)
     timers('batch generator').stop()
-    
     # Forward model.
     logits, *mems = model(tokens, position_ids, attention_mask)
     losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
@@ -124,6 +125,11 @@ def create_dataset_function(path, args):
             }
     return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
 
-if __name__ == '__main__':
-    args = get_args()
-    main(args, model_cls=BaseModel, forward_step=forward_step, create_dataset_function=create_dataset_function)
+if __name__ == '__main__':    
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--new_hyperparam', type=str, default=None)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    
+    training_main(args, model_cls=BaseModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json
index b985583..f43259b 100755
--- a/scripts/ds_config_zero.json
+++ b/scripts/ds_config_zero.json
@@ -1,5 +1,5 @@
 {
-  "train_micro_batch_size_per_gpu": 2,
+  "train_micro_batch_size_per_gpu": 1,
   "gradient_accumulation_steps": 1,
   "steps_per_print": 1,
   "gradient_clipping": 0.1,
diff --git a/scripts/pretrain_cogview2.sh b/scripts/pretrain_cogview2.sh
new file mode 100755
index 0000000..3bcb93d
--- /dev/null
+++ b/scripts/pretrain_cogview2.sh
@@ -0,0 +1,59 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name pretrain-cogview2-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --mode pretrain \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --resume-dataloader \
+       --train-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --load pretrained/cogview/cogview2-base
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index f099f80..5a09f77 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -2,7 +2,7 @@
 
 # Change for multinode config
 
-NUM_WORKERS=19
+NUM_WORKERS=1
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -10,55 +10,39 @@ script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
 main_dir=$(dirname $script_dir)
 
-# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
 OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
 HOST_FILE_PATH="hostfile"
-# OPTIONS_NCCL=""
-# HOST_FILE_PATH="hostfile_single"
+HOST_FILE_PATH="hostfile_single"
 
-small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
 full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
 
 config_json="$script_dir/ds_config_zero.json"
 gpt_options=" \
-       --experiment-name cogview-base-long \
-       --img-tokenizer-num-tokens 8192 \
-       --dataset-type CompactBinaryDataset \
+       --experiment-name pretrain-gpt2-cogview-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
        --model-parallel-size ${MP_SIZE} \
-       --num-layers 48 \
-       --hidden-size 2560 \
-       --num-attention-heads 40 \
-       --train-iters 300000 \
+       --mode pretrain \
+       --num-layers 12 \
+       --hidden-size 1024 \
+       --num-attention-heads 16 \
+       --train-iters 200000 \
        --resume-dataloader \
-       --train-data ${full_data} \
+       --train-data ${small_data} \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
        --warmup .1 \
        --checkpoint-activations \
-       --deepspeed-activation-checkpointing \
-       --max-position-embeddings 1089 \
-       --max-memory-length 0 \
+       --max-sequence-length 1089 \
        --sandwich-ln \
-       --txt-loss-scale 0.1 \
-       --sparse-type cuda_2d \
        --fp16 \
        --save-interval 2000 \
-       --no-load-optim \
-       --no-save-optim \
        --eval-interval 1000 \
-       --save $main_dir/data/checkpoints \
-       --fast-load \
-       --load data/checkpoints/cogview-base \
-       --finetune 
+       --save $main_dir/checkpoints \
 "
-          
-#        --finetune
-       # --save $main_dir/data/checkpoints \
-       #         --restart-iter 199000 
-      
-
-
+       # --load pretrained/cogview/cogview-base
 
 
 gpt_options="${gpt_options}
diff --git a/scripts/cuda_2d_text2image.sh b/scripts_old/cuda_2d_text2image.sh
similarity index 100%
rename from scripts/cuda_2d_text2image.sh
rename to scripts_old/cuda_2d_text2image.sh
diff --git a/scripts/image2text.sh b/scripts_old/image2text.sh
similarity index 100%
rename from scripts/image2text.sh
rename to scripts_old/image2text.sh
diff --git a/scripts/low_level_super_resolution.sh b/scripts_old/low_level_super_resolution.sh
similarity index 100%
rename from scripts/low_level_super_resolution.sh
rename to scripts_old/low_level_super_resolution.sh
diff --git a/scripts/post_selection.sh b/scripts_old/post_selection.sh
similarity index 100%
rename from scripts/post_selection.sh
rename to scripts_old/post_selection.sh
diff --git a/scripts_old/pretrain_multiple_nodes.sh b/scripts_old/pretrain_multiple_nodes.sh
new file mode 100755
index 0000000..f099f80
--- /dev/null
+++ b/scripts_old/pretrain_multiple_nodes.sh
@@ -0,0 +1,74 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=19
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+# OPTIONS_NCCL=""
+# HOST_FILE_PATH="hostfile_single"
+
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name cogview-base-long \
+       --img-tokenizer-num-tokens 8192 \
+       --dataset-type CompactBinaryDataset \
+       --model-parallel-size ${MP_SIZE} \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 300000 \
+       --resume-dataloader \
+       --train-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --deepspeed-activation-checkpointing \
+       --max-position-embeddings 1089 \
+       --max-memory-length 0 \
+       --sandwich-ln \
+       --txt-loss-scale 0.1 \
+       --sparse-type cuda_2d \
+       --fp16 \
+       --save-interval 2000 \
+       --no-load-optim \
+       --no-save-optim \
+       --eval-interval 1000 \
+       --save $main_dir/data/checkpoints \
+       --fast-load \
+       --load data/checkpoints/cogview-base \
+       --finetune 
+"
+          
+#        --finetune
+       # --save $main_dir/data/checkpoints \
+       #         --restart-iter 199000 
+      
+
+
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/scripts/pretrain_single_node.sh b/scripts_old/pretrain_single_node.sh
similarity index 100%
rename from scripts/pretrain_single_node.sh
rename to scripts_old/pretrain_single_node.sh
diff --git a/scripts/super_resolution.sh b/scripts_old/super_resolution.sh
similarity index 100%
rename from scripts/super_resolution.sh
rename to scripts_old/super_resolution.sh
diff --git a/scripts/testnan.sh b/scripts_old/testnan.sh
similarity index 100%
rename from scripts/testnan.sh
rename to scripts_old/testnan.sh
diff --git a/scripts/text2image.sh b/scripts_old/text2image.sh
similarity index 100%
rename from scripts/text2image.sh
rename to scripts_old/text2image.sh
diff --git a/tokenization/cogview/unified_tokenizer.py b/tokenization/cogview/unified_tokenizer.py
index bf66d9a..d72741d 100755
--- a/tokenization/cogview/unified_tokenizer.py
+++ b/tokenization/cogview/unified_tokenizer.py
@@ -35,7 +35,7 @@ class UnifiedTokenizer(object):
             ('[EOI2]', 5),
             ('[EOI3]', 6),
             ('[ROI1]', 7), # Reference
-            ('[ROI2]', 8),
+            ('[ROI2]', 8), # 58200
             ('[ROI3]', 9),
             ('[SEP]', 10),
             ('[MASK]', 11),
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 5519003..cff164c 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -36,11 +36,12 @@ from utils import print_rank_0
 from utils import get_sample_writer
 
 import mpu
-from data_utils import make_loaders, get_tokenizer
+from data_utils import make_loaders 
+from tokenization import get_tokenizer
 
 
 
-def main(args, model_cls, forward_step_function, create_dataset_function, init_function=None):
+def training_main(args, model_cls, forward_step_function, create_dataset_function, init_function=None):
     """Main training program."""
     hooks = {
         'forward_step': forward_step_function,
@@ -50,7 +51,6 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f
     
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.enabled = False # Disable CuDNN.
-    set_random_seed(args.seed) # Random seeds for reproducability.
     timers = Timers() # Timer.
     
     # Experiment Name
@@ -61,6 +61,8 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f
         
     # Pytorch distributed.
     initialize_distributed(args)
+    set_random_seed(args.seed) # Random seeds for reproducibility.
+    
     # init tokenizer
     tokenizer = get_tokenizer(args)
     # Data stuff.
@@ -71,7 +73,7 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f
 
     # Config model IO
     if args.load is not None:
-        args.iteration = load_checkpoint(model, optimizer, args)
+        args.iteration = load_checkpoint(model, args)
         # if we don't load optim_states, filelock is no more needed.
         # with FileLock("/root/checkpoint_lock", timeout=-1):
         #     args.iteration = load_checkpoint(model, optimizer, args)
@@ -82,7 +84,7 @@ def main(args, model_cls, forward_step_function, create_dataset_function, init_f
     torch.distributed.barrier()
     
     # initialize lr scheduler
-    lr_scheduler = get_learning_rate_scheduler(optimizer, args, args.iteration)
+    lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
 
     summary_writer = None
     if torch.distributed.get_rank() == 0:
@@ -217,7 +219,7 @@ def get_optimizer_param_groups(model):
     return param_groups
 
 def get_learning_rate_scheduler(optimizer, iteration, args, 
-                                auto_warmup_steps=50, auto_warmup_rate=0.05):
+                                auto_warmup_steps=100, auto_warmup_rate=0.05):
     """Build the learning rate scheduler."""
 
     # Add linear learning rate scheduler.
@@ -226,10 +228,9 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
     else:
         num_iters = args.train_iters
     num_iters = max(1, num_iters)
-    if args.mode == 'pretrain':
-        init_step = max(iteration-auto_warmup_steps, 0)
-    elif args.mode == 'finetune':
-        init_step = 0
+    init_step = max(iteration-auto_warmup_steps, 0)
+    if args.mode == 'pretrain' and iteration == 0:
+        auto_warmup_steps = 0
     # If init_step <= current_steps <= init_step + auto_warmup_steps,
     # lr = auto_warmup_rate * args.lr.
     # This overrides other rules.
@@ -335,7 +336,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
 
         # Check nan or inf in forward, preventing it from interfering loss scaler,
         # and all reduce metrics by the way
-        loss_checker = lm_loss.detach().item()
+        loss_checker = lm_loss.detach()
         for name in metrics:
             metrics[name] = metrics[name].detach()
             torch.distributed.all_reduce(metrics[name].data)
diff --git a/training/learning_rates.py b/training/learning_rates.py
index fd325c7..ac984f3 100755
--- a/training/learning_rates.py
+++ b/training/learning_rates.py
@@ -47,9 +47,9 @@ class AnnealingLR(_LRScheduler):
             return float(self.start_lr) * self.num_iters / self.warmup_iter
         else:
             if self.decay_style == self.DECAY_STYLES[0]:
-                return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/real_end_iter)
+                return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
             elif self.decay_style == self.DECAY_STYLES[1]:
-                decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / real_end_iter)
+                decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter)
                 return self.start_lr / self.decay_ratio * (
                         (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
             elif self.decay_style == self.DECAY_STYLES[2]:
diff --git a/training/model_io.py b/training/model_io.py
index fbc2254..f7527f9 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -108,7 +108,7 @@ def get_checkpoint_iteration(args):
 
     return iteration, release, True
 
-def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True):
+def load_checkpoint(model, args):
     """Load a model checkpoint."""
 
     iteration, release, success = get_checkpoint_iteration(args)
@@ -137,7 +137,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=
         else: # new params
             assert all(name.find('mixins')>0 for name in missing_keys)
             module.reinit() # initialize mixins
-            model.optimizer.refresh_fp32_params() # restore fp32 weights from module
+    model.optimizer.refresh_fp32_params() # restore fp32 weights from module
 
     # Iterations.
     if args.mode == 'finetune':
diff --git a/utils.py b/utils.py
index 8ae815c..2430bc2 100755
--- a/utils.py
+++ b/utils.py
@@ -21,9 +21,6 @@ import time
 import numpy as np
 import torch
 
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-import mpu
-import model
 from tensorboardX import SummaryWriter
 
 SUMMARY_WRITER_DIR_NAME = 'runs'
@@ -132,311 +129,3 @@ def report_memory(name):
         torch.cuda.memory_reserved() / mega_bytes)
     print_rank_0(string)
 
-
-def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
-    if release:
-        d = 'release'
-    else:
-        d = '{:d}'.format(iteration)
-    if zero:
-        dp_rank = mpu.get_data_parallel_rank()
-        d += '_zero_dp_rank_{}'.format(dp_rank)
-    return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
-
-
-def ensure_directory_exists(filename):
-    dirname = os.path.dirname(filename)
-    if not os.path.exists(dirname):
-        os.makedirs(dirname)
-
-
-def get_checkpoint_tracker_filename(checkpoints_path):
-    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
-
-
-def save_zero_checkpoint(args, iteration, optimizer):
-    zero_sd = {'iteration': iteration,
-               'optimizer_state_dict': optimizer.state_dict()}
-    zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True)
-    ensure_directory_exists(zero_checkpoint_name)
-    torch.save(zero_sd, zero_checkpoint_name)
-    print('  successfully saved {}'.format(zero_checkpoint_name))
-
-
-def save_checkpoint(iteration, model, optimizer,
-                    lr_scheduler, args):
-    """Save a model checkpoint."""
-    if args.deepspeed:
-        save_ds_checkpoint(iteration, model, lr_scheduler, args)
-    else:
-        # Only rank zer0 of the data parallel writes to the disk.
-        if isinstance(model, torchDDP):
-            model = model.module
-
-        if mpu.get_data_parallel_rank() == 0:
-            checkpoint_name = get_checkpoint_name(args.save, iteration)
-            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-                  format(torch.distributed.get_rank(), iteration, checkpoint_name))
-
-            sd = {}
-            sd['iteration'] = iteration
-            sd['module'] = model.state_dict()
-
-            # Optimizer stuff.
-            if not args.no_save_optim:
-                if optimizer is not None:
-                    sd['optimizer'] = optimizer.state_dict()
-                if lr_scheduler is not None:
-                    sd['lr_scheduler'] = lr_scheduler.state_dict()
-
-            # rng states.
-            if not args.no_save_rng:
-                sd['random_rng_state'] = random.getstate()
-                sd['np_rng_state'] = np.random.get_state()
-                sd['torch_rng_state'] = torch.get_rng_state()
-                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
-                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
-
-            ensure_directory_exists(checkpoint_name)
-            torch.save(sd, checkpoint_name)
-            print('  successfully saved {}'.format(checkpoint_name))
-
-    # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
-    # And update the latest iteration
-    if torch.distributed.get_rank() == 0:
-        tracker_filename = get_checkpoint_tracker_filename(args.save)
-        with open(tracker_filename, 'w') as f:
-            f.write(str(iteration))
-    # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()
-
-
-def save_ds_checkpoint(iteration, model, lr_scheduler, args):
-    """Save a model checkpoint."""
-
-    sd = {}
-    sd['iteration'] = iteration
-    if lr_scheduler is not None:
-        sd['client_lr_scheduler'] = lr_scheduler.state_dict()
-    # rng states.
-    if not args.no_save_rng:
-        sd['random_rng_state'] = random.getstate()
-        sd['np_rng_state'] = np.random.get_state()
-        sd['torch_rng_state'] = torch.get_rng_state()
-        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
-        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
-    if args.no_save_optim:
-        save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
-    else:
-        model.save_checkpoint(args.save, str(iteration), client_state=sd)
-
-def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
-    
-    os.makedirs(save_dir, exist_ok=True)
-
-    if tag is None:
-        tag = f"global_step{model.global_steps}"
-
-    # Ensure tag is a string
-    tag = str(tag)
-
-    # Ensure checkpoint tag is consistent across ranks
-    model._checkpoint_tag_validation(tag)
-
-    if model.save_non_zero_checkpoint:
-        model._create_checkpoint_file(save_dir, tag, False)
-        model._save_checkpoint(save_dir, tag, client_state=client_state)
-
-    # Save latest checkpoint tag
-    if save_latest:
-        with open(os.path.join(save_dir, 'latest'), 'w') as fd:
-            fd.write(tag)
-
-    return True
-
-
-def get_checkpoint_iteration(args):
-    # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
-    if not os.path.isfile(tracker_filename):
-        print_rank_0('WARNING: could not find the metadata file {} '.format(
-            tracker_filename))
-        print_rank_0('    will not load any checkpoints and will start from '
-                     'random')
-        return 0, False, False
-    iteration = 0
-    release = False
-    with open(tracker_filename, 'r') as f:
-        metastring = f.read().strip()
-        try:
-            iteration = int(metastring)
-        except ValueError:
-            release = metastring == 'release'
-            if not release:
-                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
-                    tracker_filename))
-                exit()
-
-    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
-        tracker_filename)
-
-    return iteration, release, True
-
-
-def extend_position_embedding(weight, length):
-    ori_length, hidden_size = weight.shape
-    assert length % ori_length == 0
-    position_embeddings = weight.expand(length // ori_length, -1, -1).reshape(length, hidden_size)
-    return position_embeddings
-
-def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True):
-    """Load a model checkpoint."""
-
-    iteration, release, success = get_checkpoint_iteration(args)
-
-    if not success:
-        return 0
-
-    if args.deepspeed:
-        
-        checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune)
-        # if args.finetune:
-        #     model.module.module.init_plus_from_old()
-        if (args.finetune or args.no_load_optim) and model.zero_optimization():
-            model.optimizer.refresh_fp32_params()
-        if "client_lr_scheduler" in sd and not args.finetune:
-            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
-            print_rank_0("Load lr scheduler state")
-        if checkpoint_name is None:
-            if mpu.get_data_parallel_rank() == 0:
-                print("Unable to load checkpoint.")
-            return iteration
-
-    else:
-
-        # Checkpoint.
-        checkpoint_name = get_checkpoint_name(args.load, iteration, release)
-
-        if mpu.get_data_parallel_rank() == 0:
-            print('global rank {} is loading checkpoint {}'.format(
-                torch.distributed.get_rank(), checkpoint_name))
-
-        # Load the checkpoint.
-        sd = torch.load(checkpoint_name, map_location='cpu')
-
-        if isinstance(model, torchDDP):
-            model = model.module
-
-        # Model.
-        try:
-            model.load_state_dict(sd['module'])
-        except KeyError:
-            print_rank_0('A metadata file exists but unable to load model '
-                         'from checkpoint {}, exiting'.format(checkpoint_name))
-            exit()
-
-        # Optimizer.
-        if not release and not args.finetune and not args.no_load_optim:
-            try:
-                if optimizer is not None and load_optimizer_states:
-                    optimizer.load_state_dict(sd['optimizer'])
-                if lr_scheduler is not None:
-                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
-            except KeyError:
-                print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
-                             'Specify --no-load-optim or --finetune to prevent '
-                             'attempting to load the optimizer '
-                             'state.'.format(checkpoint_name))
-                exit()
-
-    # Iterations.
-    if args.finetune or release:
-        iteration = 0
-    else:
-        try:
-            iteration = sd['iteration']
-        except KeyError:
-            try:  # Backward compatible with older checkpoints
-                iteration = sd['total_iters']
-            except KeyError:
-                print_rank_0('A metadata file exists but Unable to load iteration '
-                             ' from checkpoint {}, exiting'.format(checkpoint_name))
-                exit()
-
-    # rng states.
-    if not release and not args.finetune and not args.no_load_rng:
-        try:
-            random.setstate(sd['random_rng_state'])
-            np.random.set_state(sd['np_rng_state'])
-            torch.set_rng_state(sd['torch_rng_state'])
-            torch.cuda.set_rng_state(sd['cuda_rng_state'])
-            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
-        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
-                         'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the random '
-                         'state.'.format(checkpoint_name))
-            exit()
-
-    if mpu.get_data_parallel_rank() == 0:
-        print('  successfully loaded {}'.format(checkpoint_name))
-    del sd
-    return iteration
-
-
-def load_weights(src, dst, dst2src=False):
-    """
-    Loads weights from src to dst via in place copy.
-    src is a huggingface gpt2model, while dst is one of our models.
-    dst2src=True loads parameters from our models into huggingface's.
-    ^dst2src is still untested
-    """
-    conv_layer = 'Conv1D' in str(type(src))
-    for n, p in src.named_parameters():
-        if dst2src:
-            data = dst._parameters[n].data
-            load = p.data
-        else:
-            data = p.data
-            load = dst._parameters[n].data
-        if conv_layer and 'weight' in n:
-            data = data.t().contiguous()
-        load.copy_(data)
-
-
-#        dst._parameters[n].data.copy_(data)
-
-def load_mlp(our, oai, dst2src=False):
-    load_weights(oai.c_fc, our.dense_h_to_4h, dst2src)
-    load_weights(oai.c_proj, our.dense_4h_to_h, dst2src)
-
-
-def load_attention(our, oai, dst2src=False):
-    load_weights(oai.c_attn, our.query_key_value, dst2src)
-    load_weights(oai.c_proj, our.dense, dst2src)
-
-
-def load_transformer_layer(our, oai, dst2src=False):
-    load_weights(oai.ln_1, our.input_layernorm, dst2src)
-    load_weights(oai.ln_2, our.post_attention_layernorm, dst2src)
-    load_mlp(our.mlp, oai.mlp, dst2src)
-    load_attention(our.attention, oai.attn, dst2src)
-
-
-def move_weights(our, oai, dst2src=False):
-    """
-    Loads weights from `oai` to `our` via in place copy.
-    `oai` is a huggingface gpt2model, while `our` is one of our models.
-    dst2src=True loads parameters from our models into huggingface's.
-    ^dst2src=True is still untested
-    """
-    #    while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)):
-    #        our=our.module
-    transformer_model = oai.transformer
-    load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src)
-    load_weights(transformer_model.wte, our.word_embeddings, dst2src)
-    load_weights(transformer_model.wpe, our.position_embeddings, dst2src)
-
-    for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
-        load_transformer_layer(our_layer, oai_layer, dst2src)
\ No newline at end of file
-- 
GitLab