From 5db7dfe4ef713bbad3e6c486bbb6c6ae1df9f7c2 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Tue, 5 Oct 2021 17:57:48 +0000
Subject: [PATCH] tmp save3

---
 model/base_model.py                  |   3 +-
 model/cached_autoregressive_model.py |   4 +-
 model/cuda2d_model.py                | 153 ++++++
 model/mixins.py                      |   4 +-
 mpu/local_attention_function.py      |  79 ---
 mpu/transformer.py                   |  33 +-
 pretrain_gpt2.py                     | 775 +--------------------------
 training/deepspeed_training.py       | 583 ++++++++++++++++++++
 training/learning_rates.py           |  81 +++
 training/model_io.py                 | 162 ++++++
 10 files changed, 1015 insertions(+), 862 deletions(-)
 create mode 100644 model/cuda2d_model.py
 create mode 100644 training/deepspeed_training.py
 create mode 100755 training/learning_rates.py
 create mode 100644 training/model_io.py

diff --git a/model/base_model.py b/model/base_model.py
index 8e36ef7..7b7769c 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -41,7 +41,8 @@ class BaseModel(torch.nn.Module):
         self.mixins = torch.nn.ModuleList()
         
     def reinit(self):
-        for m in self.mixins:
+        # if some mixins are loaded, overrides this function
+        for m in self.mixins: 
             m.reinit(self.transformer)
     
     def forward(self, *args, **kwargs):
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
index e7234cb..ec2dd60 100755
--- a/model/cached_autoregressive_model.py
+++ b/model/cached_autoregressive_model.py
@@ -1,7 +1,7 @@
 # -*- encoding: utf-8 -*-
 '''
-@File    :   gpt2_modeling.py
-@Time    :   2021/10/02 00:37:22
+@File    :   cached_autoregressive_model.py
+@Time    :   2021/10/02 01:36:24
 @Author  :   Ming Ding 
 @Contact :   dm18@mail.tsinghua.edu.cn
 '''
diff --git a/model/cuda2d_model.py b/model/cuda2d_model.py
new file mode 100644
index 0000000..a9e1175
--- /dev/null
+++ b/model/cuda2d_model.py
@@ -0,0 +1,153 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   cuda2d_model.py
+@Time    :   2021/10/02 01:36:32
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+
+
+from .base_model import BaseModel
+from .mixins import PositionEmbeddingMixin, AttentionMixin
+
+from mpu.transformer import split_tensor_along_last_dim
+from mpu.local_attention_function import f_similar, f_weighting
+from mpu.random import get_cuda_rng_tracker
+from mpu.utils import sqrt
+
+
+class Cuda2dModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        additional_seqlen = args.new_sequence_length - args.max_sequence_length
+        self.mixins.append(PositionEmbeddingMixin(
+            additional_seqlen, args.hidden_size
+        ))
+        self.mixins.append(AttentionMixin(
+            num_layers=args.num_layers,
+            hidden_size=args.hidden_size
+        ))
+        self.layout = args.layout
+        # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
+        self.kernel_size = args.kernel_size
+        self.kernel_size2 = args.kernel_size2
+        self.log_attention_weights = None
+    
+    def position_embedding_forward(self, position_ids, *other_tensors):
+        position = position_ids[..., :self.layout[1]]
+        position_plus = position_ids[..., self.layout[1]:]
+        position_embeddings = torch.cat(
+                (
+                    self.transformer.position_embeddings(position),
+                    self.mixins[0].position_embeddings(position_plus)
+                ),
+                dim=-2
+            )
+        return position_embeddings
+        
+    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+        attn_module = self.transformer.layers[layer_id].attention
+        # attention_plus on all layers
+        query_key_value_plus = self.mixins[1].query_key_value[layer_id] 
+        dense_plus = self.mixins[1].dense[layer_id]
+        
+        # split two parts
+        hidden_states_plus = hidden_states[:, self.layout[1]:]
+        hidden_states = hidden_states[:, :self.layout[1]]
+        # base model qkv
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
+        # cuda2d model qkv
+        mixed_raw_layer = query_key_value_plus(hidden_states_plus)
+        q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
+        
+        dropout_fn = attn_module.attention_dropout if self.training else None
+
+        # cuda2d attention
+        context_layer0, context_layer1 = sparse_attention_2d_light(
+                q0, k0, v0,
+                q1, k1, v1,
+                mask,
+                n_head=attn_module.num_attention_heads_per_partition,
+                text_len=self.layout[0],
+                kernel_size=self.kernel_size,
+                kernel_size2=self.kernel_size2,
+                attention_dropout=dropout_fn,
+                log_attention_weights=self.log_attention_weights
+            )
+
+        output_0 = attn_module.dense(context_layer0)
+        output_1 = dense_plus(context_layer1)
+        output = torch.cat((output_0, output_1), dim=1)
+        
+        return output
+
+def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, **kwargs):
+    '''
+    q0, k0, v0: [batch_size, 1088, hidden_size]
+    q1, k1, v1: [batch_size, 4096, h2]
+    n_head: int
+    attention_mask: [batch_size, 1088, 1088]
+    '''
+    b, s0, h0 = q0.shape
+    b, s1, h1 = q1.shape
+    h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
+
+    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
+    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
+    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
+    
+    # standard attention for level 0
+    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
+    
+    if log_attention_weights is not None:
+        attention_scores += log_attention_weights
+
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+                    10000.0 * (1.0 - attention_mask)
+    
+    attention_probs0 = F.softmax(attention_scores, dim=-1)
+    
+    # local attention for level 1
+    q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
+    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
+    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
+    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
+
+    # cross attention
+    k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
+    scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
+    scores_1 = torch.cat(
+        (
+            scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]),
+            scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
+        ),
+        dim=-1)
+    attention_probs1 = F.softmax(scores_1, dim=-1)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            attention_probs0 = attention_dropout(attention_probs0)
+            attention_probs1 = attention_dropout(attention_probs1)
+        
+    # weighting for level 0
+    context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
+    # weighting for level 1
+    probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
+    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
+    context1 = context1_to_1.view(b, n_head * h, l1**2)
+    # weighting for cross attention
+    probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
+    v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
+    context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
+    context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
+    context1 = context1 + context1_to_0
+    return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
\ No newline at end of file
diff --git a/model/mixins.py b/model/mixins.py
index 4e195e3..0d304cf 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -25,12 +25,12 @@ class BaseMixin(torch.nn.Module):
         pass
 
 class PositionEmbeddingMixin(BaseMixin):
-    def __init__(self, new_sequence_length, hidden_size, 
+    def __init__(self, additional_sequence_length, hidden_size, 
                 init_method_std=0.02, reinit_slice=(-1024, None)
         ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
-        self.position_embeddings = torch.nn.Embedding(new_sequence_length, hidden_size)
+        self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
     def reinit(self, transformer, *pre_mixins):
         old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
diff --git a/mpu/local_attention_function.py b/mpu/local_attention_function.py
index 1a5af91..4147e01 100644
--- a/mpu/local_attention_function.py
+++ b/mpu/local_attention_function.py
@@ -62,82 +62,3 @@ f_similar = similarFunction.apply
 f_weighting = weightingFunction.apply
 
 
-class LocalAttention(nn.Module):
-    def __init__(self, inp_channels, out_channels, kH, kW):
-        super(LocalAttention, self).__init__()
-        self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.kH = kH
-        self.kW = kW
-
-    def forward(self, x):
-        x1 = self.conv1(x)
-        x2 = self.conv2(x)
-        x3 = self.conv3(x)
-
-        weight = f_similar(x1, x2, self.kH, self.kW)
-        weight = F.softmax(weight, -1)
-        out = f_weighting(x3, weight, self.kH, self.kW)
-
-        return out
-
-
-class TorchLocalAttention(nn.Module):
-    def __init__(self, inp_channels, out_channels, kH, kW):
-        super(TorchLocalAttention, self).__init__()
-        self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
-        self.kH = kH
-        self.kW = kW
-
-    @staticmethod
-    def f_similar(x_theta, x_phi, kh, kw, casual_mask=False):
-        n, c, h, w = x_theta.size()  # (N, inter_channels, H, W)
-        pad = (kh // 2, kw // 2)
-        x_theta = x_theta.permute(0, 2, 3, 1).contiguous()
-        x_theta = x_theta.view(n * h * w, 1, c)
-
-        x_phi = F.unfold(x_phi, kernel_size=(kh, kw), stride=1, padding=pad)
-        x_phi = x_phi.contiguous().view(n, c, kh * kw, h * w)
-        x_phi = x_phi.permute(0, 3, 1, 2).contiguous()
-        x_phi = x_phi.view(n * h * w, c, kh * kw)
-        out = x_theta @ x_phi
-        out = out.view(n, h, w, kh * kw)
-        if casual_mask:
-            out = out[..., :kh * kw // 2 + 1]
-        return out
-
-    @staticmethod
-    def f_weighting(x_theta, x_phi, kh, kw, casual_mask=False):
-        n, c, h, w = x_theta.size()  # (N, inter_channels, H, W)
-        pad = (kh // 2, kw // 2)
-        x_theta = F.unfold(x_theta, kernel_size=(kh, kw), stride=1, padding=pad)
-        x_theta = x_theta.permute(0, 2, 1).contiguous()
-        x_theta = x_theta.view(n * h * w, c, kh * kw)
-
-        if casual_mask:
-            x_theta = x_theta[..., :kh * kw // 2 + 1]
-            x_phi = x_phi.view(n * h * w, kh * kw // 2 + 1, 1)
-        else:   
-            x_phi = x_phi.view(n * h * w, kh * kw, 1)
-
-        out = torch.matmul(x_theta, x_phi)
-        out = out.squeeze(-1)
-        out = out.view(n, h, w, c)
-        out = out.permute(0, 3, 1, 2).contiguous()
-
-        return out
-
-    def forward(self, x):
-        x1 = self.conv1(x)
-        x2 = self.conv2(x)
-        x3 = self.conv3(x)
-
-        weight = self.f_similar(x1, x2, self.kH, self.kW)
-        weight = F.softmax(weight, -1)
-        out = self.f_weighting(x3, weight, self.kH, self.kW)
-
-        return out
-    
diff --git a/mpu/transformer.py b/mpu/transformer.py
index c19ef2a..9eab060 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,13 +52,13 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
         query_layer / math.sqrt(query_layer.shape[-1]),
         key_layer.transpose(-1, -2)
     )
+    if log_attention_weights is not None:
+        attention_scores += log_attention_weights
     
     if attention_mask.shape[-2] > 1: # if auto-regressive, skip
         attention_scores = torch.mul(attention_scores, attention_mask) - \
                     10000.0 * (1.0 - attention_mask)
-    if log_attention_weights is not None:
-        attention_scores += log_attention_weights
-    
+
     attention_probs = F.softmax(attention_scores, dim=-1)
 
     if attention_dropout is not None:
@@ -179,33 +179,6 @@ class MLP(torch.nn.Module):
 
 
 class BaseTransformerLayer(torch.nn.Module):
-    """A single layer transformer for GPT2.
-
-    We use the following notation:
-        h: hidden size
-        n: number of attention heads
-        b: batch size
-        s: sequence length
-    Transformore layer takes input with size [b, s, h] and returns an
-    output of the same size.
-
-    Arguments:
-        hidden_size: The hidden size of the self attention.
-        num_attention_heads: number of attention head in the self
-                             attention.
-        attention_dropout_prob: dropout probability of the attention
-                                score in self attention.
-        output_dropout_prob: dropout probability for the outputs
-                             after self attention and final output.
-        layernorm_epsilon: epsilon used in layernorm to avoid
-                           division by zero.
-        init_method: initialization method used for the weights. Note
-                     that all biases are initialized to zero and
-                     layernorm weight are initialized to one.
-        output_layer_init_method: output layers (attention output and
-                                  mlp output) initialization. If None,
-                                  use `init_method`.
-    """
     def __init__(
         self,
         hidden_size,
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index c8ac52a..3b8ccc4 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -1,232 +1,42 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# -*- encoding: utf-8 -*-
+'''
+@File    :   pretrain_gpt2.py
+@Time    :   2021/10/06 00:58:32
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
 
-from datetime import datetime
+# here put the import lib
 import os
-import random
+import sys
 import math
-from filelock import FileLock
-import numpy as np
+import random
 import torch
 
-import deepspeed
-from contextlib import ExitStack
-from arguments import get_args
-from learning_rates import AnnealingLR
-
 import mpu
-from mpu import GPT2ParallelTransformer
-from utils import Timers
-from utils import save_checkpoint
-from utils import load_checkpoint
-from utils import report_memory
-from utils import print_args
-from utils import print_rank_0
-from utils import get_sample_writer
-import torch.distributed as dist
-
-from data_utils import make_loaders, get_tokenizer, detect_new_datasets
-
-import stat
-
-def get_model(args, sparse_config=None):
-    """Build the model."""
-
-    print_rank_0('building CogView2 model ...')
-    ml = args.max_position_embeddings
-    model = GPT2ParallelTransformer(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=ml,
-                      max_memory_length=args.max_memory_length,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      parallel_output=True,
-                      sparse_config=sparse_config if sparse_config is not None else args.sparse_config,
-                      sandwich_ln=args.sandwich_ln,
-                      finetune=args.finetune
-                      )
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
-
-    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
-    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
-        model.half()
-
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-
-    # Fp16 conversion.
-    # if args.fp16:
-    #     model = FP16_Module(model)
-
-    # Wrap model for distributed training.
-    # if not args.deepspeed:
-    #     if USE_TORCH_DDP:
-    #         i = torch.cuda.current_device()
-    #         model = DDP(model, device_ids=[i], output_device=i,
-    #                     process_group=mpu.get_data_parallel_group())
-    #     else:
-    #         model = DDP(model)
-
-    return model
-
-
-def gpt2_get_params_for_weight_decay_optimization(module):
-    
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None and p.requires_grad])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias' and p.requires_grad])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias' and p.requires_grad])
-    return weight_decay_params, no_weight_decay_params
-
-
-def get_optimizer_param_groups(model):
-    # Build parameter groups (weight decay and non-decay).
-    while hasattr(model, 'module'):
-        print(model)
-        model = model.module
-        
-    param_groups = gpt2_get_params_for_weight_decay_optimization(model) # TODO move to here
-
-    # Add model parallel attribute if it is not set.
-    for param_group in param_groups:
-        for param in param_group['params']:
-            if not hasattr(param, 'model_parallel'):
-                param.model_parallel = False
-
-    return param_groups
-
-def get_learning_rate_scheduler(optimizer, args):
-    """Build the learning rate scheduler."""
-
-    # Add linear learning rate scheduler.
-    if args.lr_decay_iters is not None:
-        num_iters = args.lr_decay_iters
-    else:
-        num_iters = args.train_iters
-    num_iters = max(1, num_iters - args.restart_iter)
-    init_step = -1
-    warmup_iter = args.warmup * num_iters
-    lr_scheduler = AnnealingLR(optimizer,
-                               start_lr=args.lr,
-                               warmup_iter=warmup_iter,
-                               num_iters=num_iters,
-                               decay_style=args.lr_decay_style,
-                               last_iter=init_step,
-                               decay_ratio=args.lr_decay_ratio,
-                               restart_iter=args.restart_iter
-                               )
-
-    return lr_scheduler
-
-def setup_model_and_optimizer(args):
-    """Setup model and optimizer."""
-
-    model = get_model(args)
-
-    if args.finetune: # TODO
-        model.requires_grad_(False)
-        for name, param in model.named_parameters():
-            # if name.find('_plus') > 0:
-            if name.find('query_key_value') >= 0 or name.find('attention.dense') >= 0 or name.find('position_embeddings') >= 0:
-                param.requires_grad_(True)
-
-    param_groups = get_optimizer_param_groups(model)
-
-    if args.train_data is not None:
-        if args.deepspeed:
-            print_rank_0("DeepSpeed is enabled.")
-            model, optimizer, _, _ = deepspeed.initialize(
-                model=model,
-                model_parameters=param_groups,
-                args=args,
-                mpu=mpu,
-                dist_init_required=False
-            )
-        else:
-            raise ValueError('Currently, we only support training with deepspeed.')
-        lr_scheduler = get_learning_rate_scheduler(optimizer, args)
-    else:
-        optimizer, lr_scheduler = None, None
-
-    return model, optimizer, lr_scheduler
-
+from arguments import get_args
+from model.base_model import BaseModel
+from training.deepspeed_training import main
 
 def get_masks_and_position_ids(data,
                             loss_mask=None,
                             attention_mask=None, args=None):
-    assert args is not None
     # Extract batch size and sequence length.
     batch_size, seq_length = data.size()
 
     # Attention mask (lower triangular).
     if attention_mask is None:
-        if args.sparse_config.sparse_type == 'cuda_2d':
-            # single direction, [PAD]s are at the start of the seq.
-            assert loss_mask is not None
-            # loss_mask has n_pad(+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse.
-            attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril()
-            for i in range(batch_size):
-                attention_mask[i].fill_diagonal_(1)
-            attention_mask = attention_mask.unsqueeze(1)
-        elif args.sparse_config.sparse_type == 'standard':
-            attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
-            attention_mask.tril_()
-        else:
-            raise NotImplementedError
-
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
+        attention_mask.tril_()
+        
     # Loss mask.
     if loss_mask is None:
         loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
 
     # Position ids.
-    if args.sparse_config.sparse_type == 'cuda_2d':
-        assert loss_mask is not None
-        layout = args.layout
-        assert seq_length == layout[-1]
-        n_pads = seq_length - loss_mask.sum(dim=-1).long()
-        position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
-                                    device=data.device)
-        for i in range(batch_size):
-            torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]], 
-                dtype=torch.long, device=data.device)
-            torch.arange(layout[2] - layout[1], 
-                out=position_ids[i, layout[1]:],
-                dtype=torch.long, device=data.device)
-    else:
-        position_ids = torch.arange(seq_length, dtype=torch.long,
-                                    device=data.device)
-        position_ids = position_ids.unsqueeze(0).expand_as(data)
+    position_ids = torch.arange(seq_length, dtype=torch.long,
+                                device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
 
     return attention_mask, loss_mask, position_ids
 
@@ -269,7 +79,7 @@ def get_batch(data_iterator, args, timers):
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
-def forward_step(data_iterator, model, args, timers, mems):
+def forward_step(data_iterator, model, args, timers):
     """Forward step."""
 
     # Get the batch.
@@ -277,549 +87,18 @@ def forward_step(data_iterator, model, args, timers, mems):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator, args, timers)
     timers('batch generator').stop()
-
-    # split img & txt positions, [PAD] not included # TODO check enough
-    tokenizer = get_tokenizer()
-    img_txt_sep = tokenizer.img_tokenizer.num_tokens
-    img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0)
-    txt_indices_bool = (~img_indices_bool) & (loss_mask > 0)
+    
     # Forward model.
-    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
-    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
-                                              labels)
+    logits, *mems = model(tokens, position_ids, attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
     # scaling loss mask
-    loss_mask[txt_indices_bool] *= args.txt_loss_scale
     loss_mask = loss_mask.view(-1)  
 
     losses = losses.view(-1) * loss_mask
     loss = torch.sum(losses) / loss_mask.sum()
-    # =====================   Log partial losses   ======================== #
-    if args.sparse_config.sparse_type == 'cuda_2d':
-        img_indices_bool2 = img_indices_bool.clone()
-        img_indices_bool2[:, :args.sparse_config.layout[1]] = False
-        img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1)
-        torch.distributed.all_reduce(img_loss2.data)
-        img_loss2.data = img_loss2.data / args.world_size
-        img_indices_bool[:, args.sparse_config.layout[1]:] = False
-    else:
-        img_loss2 = 0
-    img_indices_bool = img_indices_bool.view(-1)
-    txt_indices_bool = txt_indices_bool.view(-1)
-    img_loss = losses[img_indices_bool].detach().sum() / max(img_indices_bool.sum(), 1)
-    txt_loss = losses[txt_indices_bool].detach().sum() / max(txt_indices_bool.sum(), 1) / args.txt_loss_scale
-
-    # Reduce losses for logging
-    torch.distributed.all_reduce(img_loss.data)
-    torch.distributed.all_reduce(txt_loss.data)
-    img_loss.data = img_loss.data / args.world_size
-    txt_loss.data = txt_loss.data / args.world_size
-
-    # ===================== END OF BLOCK ======================= #
-    return loss, mems, img_loss, txt_loss, img_loss2
-
-
-def backward_step(optimizer, model, lm_loss, args, timers):
-    """Backward step."""
-
-    # Total loss.
-    loss = lm_loss
-
-    # Backward pass.
-    if args.deepspeed:
-        model.backward(loss)
-    else:
-        optimizer.zero_grad()
-        if args.fp16:
-            optimizer.backward(loss, update_master_grads=False)
-        else:
-            loss.backward()
-
-    reduced_losses = lm_loss.view(1)
-
-    # Reduce losses for logging
-    torch.distributed.all_reduce(reduced_losses.data)
-    reduced_losses.data = reduced_losses.data / args.world_size
-
-    if args.deepspeed:
-        # DeepSpeed backward propagation already addressed all reduce communication.
-        # Reset the timer to avoid breaking timer logs below.
-        timers('allreduce').reset()
-    # else:
-    #     if not USE_TORCH_DDP:
-    #         timers('allreduce').start()
-    #         model.allreduce_params(reduce_after=False,
-    #                                fp32_allreduce=args.fp32_allreduce)
-    #         timers('allreduce').stop()
-
-    lm_loss_reduced = reduced_losses
-
-    # Update master gradients.
-    # if not args.deepspeed:
-    #     if args.fp16:
-    #         optimizer.update_master_grads()
-
-    #     # Clipping gradients helps prevent the exploding gradient.
-    #     if args.clip_grad > 0:
-    #         if not args.fp16:
-    #             mpu.clip_grad_norm(model.parameters(), args.clip_grad)
-    #         else:
-    #             optimizer.clip_master_grads(args.clip_grad)
-
-    return lm_loss_reduced
-
-
-def see_memory_usage(message, force=False):
-    if not force:
-        return
-    dist.barrier()
-    if dist.get_rank() == 0:
-        print(message)
-        print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
-        print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
-        print(" ")
-
-def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, mems):
-    """Single training step."""
-    while True:
-        # Forward model for one step.
-        timers('forward').start()
-        lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems)
-        timers('forward').stop()
-
-        if (img_loss + txt_loss).isnan().any() or (img_loss + txt_loss).isinf().any():
-            print('Skipping backward and optimizer step for nan or inf in forwarding!')
-            return (img_loss + txt_loss), 1, mems, img_loss, txt_loss, img_loss2
-
-        # Calculate gradients, reduce across processes, and clip.
-        timers('backward').start()
-        lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
-        timers('backward').stop()
-        # Update parameters.
-        skipped_iter, complete = 0, False
-        timers('optimizer').start()
-        if args.deepspeed:
-            if model.is_gradient_accumulation_boundary():
-                model.step()
-                complete = True
-                if not (args.fp16 and optimizer.overflow):
-                    lr_scheduler.step()
-                else:
-                    skipped_iter = 1
-            else:
-                model.step()
-        else:
-            optimizer.step()
-            complete = True
-            # Update learning rate.
-            if not (args.fp16 and optimizer.overflow):
-                lr_scheduler.step()
-            else:
-                skipped_iter = 1
-        timers('optimizer').stop()
-        if complete:
-            break
-    return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss, img_loss2
-
-
-def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss, img_loss2):
-    log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
-    log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
-    log_string += ' learning rate {:.3E} |'.format(lr)
-    log_string += ' lm loss {:.6E} |'.format(loss)
-    log_string += ' img loss {:.6E} |'.format(img_loss)
-    if args.sparse_config.sparse_type == 'cuda_2d':
-        log_string += ' img loss2 {:.6E} |'.format(img_loss2)
-    log_string += ' unscaled txt loss {:.6E} |'.format(txt_loss)
-    if args.fp16:
-        log_string += ' loss scale {:.1f} |'.format(
-            optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)
-    print_rank_0(log_string)
-    if summary_writer is not None:
-        summary_writer.add_scalar(f'Train/lr', lr, step)
-        summary_writer.add_scalar(f'Train/train_loss', loss, step)
-        summary_writer.add_scalar(f'Train/elapsed_time', elapsed_time, step)
-
-
-def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
-    string = ' validation loss at {} | '.format(prefix)
-    string += 'LM loss: {:.6E} | '.format(loss)
-    string += 'LM PPL: {:.6E}'.format(ppl)
-    length = len(string) + 1
-    print_rank_0('-' * 100)
-    print_rank_0('-' * length)
-    print_rank_0(string)
-    print_rank_0('-' * length)
-    if summary_writer is not None:
-        summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
-        summary_writer.add_scalar(f'Train/valid_loss', loss, step)
-
-
-def train(model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args, summary_writer=None):
-    """Train the model."""
-    # Turn on training mode which enables dropout.
-    model.train()
-
-    # Tracking loss.
-    total_lm_loss = 0.0
-    total_img_loss = total_txt_loss = 0.0
-
-    # Iterations.
-    skipped_iters = 0
-
-    timers('interval time').start()
-    report_memory_flag = True
-    mems = []
-    while args.iteration < args.train_iters:
-
-        if args.iteration % 100 == 0:
-            new_loaders = detect_new_datasets(args)
-            if new_loaders is not None:
-                print(f'Loading new datasets ... Now we train models on {args.train_data}.')
-                train_data_iterator = iter(new_loaders[0])
-                val_data_iterator = iter(new_loaders[1])
-                # TODO close the original
-
-
-        lm_loss, skipped_iter, mems, img_loss, txt_loss, img_loss2 = train_step(train_data_iterator,
-                                           model,
-                                           optimizer,
-                                           lr_scheduler,
-                                           args, timers, mems)
-        skipped_iters += skipped_iter
-        args.iteration += 1
-
-        # Update losses.
-        total_lm_loss += lm_loss.data.detach().float()
-        total_img_loss += img_loss.data.detach().float()
-        total_txt_loss += txt_loss.data.detach().float()
-
-        # Logging.
-        if args.iteration % args.log_interval == 0:
-            learning_rate = optimizer.param_groups[0]['lr']
-            avg_lm_loss = total_lm_loss.item() / args.log_interval
-            # average img & txt loss
-            avg_img_loss = total_img_loss.item() / args.log_interval
-            avg_txt_loss = total_txt_loss.item() / args.log_interval
-
-            elapsed_time = timers('interval time').elapsed()
-            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
-                                    elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
-                                    avg_img_loss, avg_txt_loss, img_loss2)
-            total_lm_loss = 0.0
-            total_img_loss = 0.0
-            total_txt_loss = 0.0
-            if report_memory_flag:
-                report_memory('after {} iterations'.format(args.iteration))
-                report_memory_flag = False
-            # if USE_TORCH_DDP:
-            #     timers.log(['forward', 'backward', 'optimizer',
-            #                 'batch generator', 'data loader'],
-            #                normalizer=args.log_interval)
-            # else:
-            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
-                            'batch generator', 'data loader'],
-                        normalizer=args.log_interval)
-        # Checkpointing
-        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
-            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
-
-        # Evaluation
-        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
-            prefix = 'iteration {}'.format(args.iteration)
-            evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration, summary_writer=summary_writer)
-
-        if args.exit_interval and args.iteration % args.exit_interval == 0:
-            torch.distributed.barrier()
-            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-            rank = torch.distributed.get_rank()
-            print('rank: {} | time: {} | exiting the program at iteration {}'.
-                  format(rank, time_str, args.iteration), flush=True)
-            exit()
-
-    return args.iteration, skipped_iters
-
-
-def evaluate(data_iterator, model, args, timers, verbose=False):
-    """Evaluation."""
-
-    # Turn on evaluation mode which disables dropout.
-    model.eval()
-
-    total_lm_loss = 0
-    mems = []
-    # with open('grad_scale_fp32.txt', 'w') as fout: 
-    with torch.no_grad():
-        iteration = 0
-        while iteration < args.eval_iters:
-            iteration += 1
-            if verbose and iteration % args.log_interval == 0:
-                print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
-            # Forward evaluation.
-            lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems=mems)
-
-            # (lm_loss).backward()
-            # for name, param in model.named_parameters():
-            #     v_max = param.data.abs().max().item()
-            #     v_mean = param.data.abs().mean().item()
-            #     fout.write(f'name: {name}, v_max: {v_max}, v_mean: {v_mean}\n')
-            #     if param.grad is not None:
-            #         g_max = param.grad.max().item()
-            #         g_zero_bool = param.grad.abs() == 0
-            #         g_zero = (g_zero_bool.sum() / param.grad.numel()).item()
-            #         g_nz_mean = param.grad[~g_zero_bool].abs().mean().item()
-            #         fout.write(f'g_max: {g_max}, g_zero: {g_zero}, g_nz_mean: {g_nz_mean}\n')
-            # fout.flush()
-            # import pdb;pdb.set_trace()
-            '''when contiguous memory optimizations are enabled, the buffers
-            allocated by the optimizations are deallocated during backward pass
-            in the absence of backward pass the buffers should be reset after each
-            forward pass'''
-            if args.deepspeed and args.deepspeed_activation_checkpointing:
-                deepspeed.checkpointing.reset()
-
-            # Reduce across processes.
-            # if isinstance(model, DDP):
-            #     torch.distributed.all_reduce(lm_loss.data)
-            #     lm_loss.data = lm_loss.data / args.world_size
-
-            total_lm_loss += lm_loss.data.detach().float().item()
-
-    # Move model back to the train mode.
-    model.train()
-
-    total_lm_loss /= args.eval_iters
-    return total_lm_loss
-
-
-def evaluate_and_print_results(prefix, data_iterator, model,
-                               args, timers, verbose=False, step=None, summary_writer=None):
-    """Helper function to evaluate and dump results on screen."""
-    # import line_profiler
-    # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
-    # profile.enable()
-    # torch.cuda.empty_cache()
-    lm_loss = evaluate(data_iterator, model, args, timers, verbose)
-    # profile.disable()  # 停止分析
-    # import sys
-    # profile.print_stats(sys.stdout)
-    lm_ppl = math.exp(min(20, lm_loss))
-    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
-
-    return lm_loss
-
-
-'''
-    Optional DeepSpeed Activation Checkpointing features
-    Gives access to partition activations, contiguous memory optimizations
-    and cpu checkpointing.
-
-    Activation checkpoint requires keep track of the random states
-    and setting the random seed for each MP process. Megatron uses
-    mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
-    for keeping track of the random states and setting the random seeds.
-    Since they are used in places outside of activation checkpointing,
-    we overwrite them to maintain consistency.
-
-    This must be done before all the calls to mpu.model_parallel_cuda_manual_seed
-    '''
-
-
-def set_deepspeed_activation_checkpointing(args):
-    deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
-    mpu.checkpoint = deepspeed.checkpointing.checkpoint
-    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
-    mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
-
-
-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-    # Optional DeepSpeed Activation Checkpointing Features
-    #
-    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
-        set_deepspeed_activation_checkpointing(args)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducability."""
-
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
-def get_train_val_test_data(args):
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
-
-    (train_data, val_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        train_data, val_data, test_data = make_loaders(args)
-        num_tokens = get_tokenizer().num_tokens
-
-        before = num_tokens
-        after = before
-        multiple = args.make_vocab_size_divisible_by * \
-                   mpu.get_model_parallel_world_size()
-        while (after % multiple) != 0:
-            after += 1
-        print_rank_0('> padded vocab (size: {}) with {} dummy '
-                     'tokens (new size: {})'.format(
-                         before, after - before, after))
-        token_counts = torch.cuda.LongTensor(
-            [after, int(args.do_train), int(args.do_valid), int(args.do_test)])
-    else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
-    # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    args.do_train = token_counts[1].item()
-    args.do_valid = token_counts[2].item()
-    args.do_test = token_counts[3].item()
-
-    return train_data, val_data, test_data, num_tokens
-
-
-def main():
-    """Main training program."""
-
-    # Disable CuDNN.
-    torch.backends.cudnn.enabled = False
-    # Timer.
-    timers = Timers()
+    
+    return loss, {}
 
-    # Arguments.
+if __name__ == '__main__':
     args = get_args()
-    if args.load:
-        args.experiment_name = os.path.basename(os.path.normpath(args.load))
-    else:
-        args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
-    if args.save:
-        args.save = os.path.join(args.save, args.experiment_name)
-    # Pytorch distributed.
-    initialize_distributed(args)
-
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
-
-    # init tokenizer
-    tokenizer = get_tokenizer(args)
-
-    # Data stuff.
-    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args)
-
-    # Model, optimizer, and learning rate.
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
-
-    if args.load is not None:
-        if args.fast_load:
-            args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
-        else:
-            with FileLock("/root/checkpoint_lock", timeout=-1):
-                args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
-    else:
-        args.iteration = 0
-    torch.distributed.barrier()
-
-    summary_writer = None
-    if torch.distributed.get_rank() == 0:
-        if args.finetune:
-            print('Finetune CogView model')
-        else:
-            print('Pretrain CogView model')
-        print_args(args)
-        summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)
-
-    # Resume data loader if necessary.
-    if args.resume_dataloader:
-        if train_data is not None:
-            train_data.batch_sampler.start_iter = args.iteration % \
-                                                  len(train_data)
-        if val_data is not None:
-            start_iter_val = (args.train_iters // args.save_interval) * \
-                             args.eval_interval
-            val_data.batch_sampler.start_iter = start_iter_val % \
-                                                len(val_data)
-    if train_data is not None:
-        train_data_iterator = iter(train_data)
-    else:
-        train_data_iterator = None
-    if val_data is not None:
-        val_data_iterator = iter(val_data)
-    else:
-        val_data_iterator = None
-
-    # TODO: figure out how to properly set this especially when resuming training
-    iteration = 0
-    if args.train_iters > 0:
-        if args.do_train:
-            with ExitStack() as stack:
-                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
-                    save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
-                # stack.callback(save_on_exit, args, model, optimizer, lr_scheduler)
-                iteration, skipped = train(model, optimizer,
-                                           lr_scheduler,
-                                           train_data_iterator,
-                                           val_data_iterator,
-                                           timers, args, summary_writer=summary_writer)
-
-        if args.do_valid:
-            prefix = 'the end of training for val data'
-            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                                                  model, args, timers, False)
-
-    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
-
-    if test_data is not None:
-        test_data_iterator = iter(test_data)
-    else:
-        test_data_iterator = None
-
-    if args.do_test:
-        # Run on test data.
-        prefix = 'the end of training for test data'
-        evaluate_and_print_results(prefix, test_data_iterator,
-                                   model, args, timers, True)
-
-def seed_torch(seed=1029):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.enabled = False
-
-if __name__ == "__main__":
-    torch.backends.cuda.matmul.allow_tf32 = False
-    main()
+    main(args, model_cls=BaseModel, forward_step=forward_step)
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
new file mode 100644
index 0000000..95d8ac4
--- /dev/null
+++ b/training/deepspeed_training.py
@@ -0,0 +1,583 @@
+# coding=utf-8
+# Rewrite by Ming Ding, Tsinghua University
+# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import math
+import numpy as np
+import torch
+from collections import defaultdict
+from datetime import datetime
+from contextlib import ExitStack
+
+import torch.distributed as dist
+import deepspeed
+
+from .learning_rates import AnnealingLR
+from .model_io import load_checkpoint, save_checkpoint
+
+from utils import Timers
+from utils import report_memory
+from utils import print_args
+from utils import print_rank_0
+from utils import get_sample_writer
+
+import mpu
+from data_utils import make_loaders, get_tokenizer
+
+
+
+def main(args, model_cls, forward_step, init_step=None):
+    """Main training program."""
+    hooks = {
+        'forward_step': forward_step,
+        'init_step': init_step
+        }
+    
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.enabled = False # Disable CuDNN.
+    set_random_seed(args.seed) # Random seeds for reproducibility.
+    timers = Timers() # Timer.
+    
+    # Experiment Name
+    if args.load and args.mode == 'pretrain': # continue training
+        args.experiment_name = os.path.basename(os.path.normpath(args.load))
+    else:
+        args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
+        
+    # Pytorch distributed.
+    initialize_distributed(args)
+    # init tokenizer
+    tokenizer = get_tokenizer(args)
+    # Data stuff.
+    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args)
+
+    # Model, optimizer, and learning rate.
+    model, optimizer = setup_model_and_optimizer(args, model_cls)
+
+    # Config model IO
+    if args.load is not None:
+        args.iteration = load_checkpoint(model, optimizer, args)
+        # if we don't load optim_states, filelock is no more needed.
+        # with FileLock("/root/checkpoint_lock", timeout=-1):
+        #     args.iteration = load_checkpoint(model, optimizer, args)
+    else:
+        args.iteration = 0
+    if args.save:
+        args.save = os.path.join(args.save, args.experiment_name)
+    torch.distributed.barrier()
+    
+    # initialize lr scheduler
+    lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
+
+    summary_writer = None
+    if torch.distributed.get_rank() == 0:
+        if args.mode == 'pretrain':
+            print('Pretraining or Continuing training the Model...')
+        elif args.mode == 'finetune':
+            print('Finetuning Model...')
+        print_args(args)
+        summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)
+
+    # Resume data loader if necessary.
+    if args.resume_dataloader:
+        if train_data is not None:
+            train_data.batch_sampler.start_iter = args.iteration % len(train_data)
+        if val_data is not None:
+            start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
+            val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
+    if train_data is not None:
+        train_data_iterator = iter(train_data)
+    else:
+        train_data_iterator = None
+    if val_data is not None:
+        val_data_iterator = iter(val_data)
+    else:
+        val_data_iterator = None
+        
+    # init hook before training
+    if hooks['init_step'] is not None:
+        hooks['init_step'](args, model, optimizer)
+
+    # training 
+    iteration = 0
+    if args.train_iters > 0:
+        if args.do_train:
+            with ExitStack() as stack:
+                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
+                    save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
+                iteration, skipped = train(model, optimizer,
+                    lr_scheduler,
+                    train_data_iterator,
+                    val_data_iterator,
+                    timers, args, summary_writer=summary_writer,
+                    hooks=hooks
+                    )
+        if args.do_valid:
+            prefix = 'the end of training for val data'
+            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
+                model, args, timers, False)
+
+    # final save
+    if args.save and iteration != 0: # TODO save
+        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+
+    # final testing
+    if args.do_test and test_data is not None:
+        prefix = 'the end of training for test data'
+        evaluate_and_print_results(prefix, iter(test_data),
+            model, args, timers, True)
+
+def get_model(args, model_cls):
+    """Build the model."""
+
+    print_rank_0(f'building {model_cls.__name__} model ...')
+    model = model_cls(args)
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' > number of parameters on model parallel rank {}: {}'.format(
+            mpu.get_model_parallel_rank(),
+            sum([p.nelement() for p in model.parameters()])), flush=True)
+
+    if args.fp16:
+        model.half()
+    model.cuda(torch.cuda.current_device())
+
+    return model
+
+def setup_model_and_optimizer(args, model_cls):
+    """Setup model and optimizer."""
+
+    model = get_model(args, model_cls)
+    
+    model.disable_untrainable_params() # mark trainable params
+
+    param_groups = get_optimizer_param_groups(model)
+
+    if args.train_data is not None:
+        if args.deepspeed:
+            print_rank_0("DeepSpeed is enabled.")
+            model, optimizer, _, _ = deepspeed.initialize(
+                model=model,
+                model_parameters=param_groups,
+                args=args,
+                mpu=mpu,
+                dist_init_required=False
+            )
+        else:
+            raise ValueError('Currently, we only support training with deepspeed.')
+    else:
+        optimizer = None
+
+    return model, optimizer
+
+
+def get_params_for_weight_decay_optimization(module):
+    
+    weight_decay_params = {'params': []}
+    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    for module_ in module.modules():
+        if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
+            no_weight_decay_params['params'].extend(
+                [p for p in list(module_._parameters.values())
+                 if p is not None and p.requires_grad])
+        else:
+            weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n != 'bias' and p.requires_grad])
+            no_weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n == 'bias' and p.requires_grad])
+    return weight_decay_params, no_weight_decay_params
+
+def get_optimizer_param_groups(model):
+    # Build parameter groups (weight decay and non-decay).
+    if hasattr(model, 'module'):
+        model = model.module
+    param_groups = get_params_for_weight_decay_optimization(model) # TODO move to here
+    # Add model parallel attribute if it is not set.
+    for param_group in param_groups:
+        for param in param_group['params']:
+            if not hasattr(param, 'model_parallel'):
+                param.model_parallel = False
+    return param_groups
+
def get_learning_rate_scheduler(optimizer, iteration, args, 
                                auto_warmup_steps=50, auto_warmup_rate=0.05):
    """Build the learning rate scheduler.

    Args:
        optimizer: optimizer whose param-group lrs the scheduler drives.
        iteration: iteration to resume from (fast-forwards the schedule
            when continuing a pretraining run).
        args: namespace with lr, warmup, train_iters, lr_decay_*, mode.
        auto_warmup_steps, auto_warmup_rate: for the first
            ``auto_warmup_steps`` steps after (re)start the lr is pinned to
            ``auto_warmup_rate * args.lr``; this overrides all other rules.

    Returns:
        An AnnealingLR instance.
    """
    # Total number of iterations the decay schedule spans.
    if args.lr_decay_iters is not None:
        num_iters = args.lr_decay_iters
    else:
        num_iters = args.train_iters
    num_iters = max(1, num_iters)
    if args.mode == 'pretrain':
        # Resume mid-schedule, backing off by the auto-warmup window.
        init_step = max(iteration - auto_warmup_steps, 0)
    else:
        # Finetuning (or any other mode) restarts the schedule from scratch.
        # Bug fix: the original bound init_step only for 'pretrain'/'finetune',
        # raising UnboundLocalError for every other mode.
        init_step = 0
    warmup_iter = args.warmup * num_iters
    lr_scheduler = AnnealingLR(optimizer,
        start_lr=args.lr,
        warmup_iter=warmup_iter,
        num_iters=num_iters,
        decay_style=args.lr_decay_style,
        last_iter=init_step,
        decay_ratio=args.lr_decay_ratio,
        auto_warmup_steps=auto_warmup_steps,
        auto_warmup_rate=auto_warmup_rate
        )

    return lr_scheduler
+
+
def train(model, optimizer, lr_scheduler,
        train_data_iterator, val_data_iterator, timers, args, 
        summary_writer=None, hooks={}):
    """Train the model.

    Runs the main loop from args.iteration up to args.train_iters, handling
    periodic logging, checkpointing, evaluation, and an optional hard exit
    for time-limited jobs.

    Returns:
        (final iteration, number of skipped iterations).
    """
    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0
    total_metrics = defaultdict(float)

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while args.iteration < args.train_iters:

        lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
                                        model,
                                        optimizer,
                                        lr_scheduler,
                                        args, timers, hooks=hooks)
        skipped_iters += skipped_iter
        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()
        for name in metrics:
            total_metrics[name] += metrics[name].data.detach().float().item()

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            # average img & txt loss
            avg_metrics = {}
            for key in total_metrics:
                avg_metrics[key] = total_metrics[key] / args.log_interval

            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                    elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
                                    avg_metrics)
            # Reset the accumulators for the next logging interval.
            total_lm_loss = 0.0
            total_metrics = defaultdict(float)
            # Report memory usage only once, after the first logged interval.
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False

            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                            'batch generator', 'data loader'],
                        normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)

        # Optional hard exit for time-limited jobs; all ranks synchronize first.
        if args.exit_interval and args.iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, args.iteration), flush=True)
            exit()

    return args.iteration, skipped_iters
+
+
def train_step(data_iterator, model, optimizer, lr_scheduler,
               args, timers, hooks={}):
    """Single training step.

    Repeats forward/backward until the DeepSpeed engine reaches a
    gradient-accumulation boundary and the optimizer actually steps.

    Returns:
        (reduced lm_loss, skipped_iter, metrics): skipped_iter is 1 when the
        step was abandoned because of nan/inf losses or fp16 overflow.
    """
    forward_step = hooks['forward_step']

    while True:
        # Forward model for one step.
        timers('forward').start()
        lm_loss, metrics = forward_step(data_iterator, model, args, timers)
        timers('forward').stop()

        # Check nan or inf in forward, preventing it from interfering loss scaler,
        # and all reduce metrics by the way.
        # Bug fix: keep the checker a tensor so .isnan()/.isinf() always exist
        # (the previous float start value via .item() crashed with
        # AttributeError whenever `metrics` was empty), and clone() so the
        # in-place += cannot write through to lm_loss's storage.
        loss_checker = lm_loss.detach().clone()
        for name in metrics:
            metrics[name] = metrics[name].detach()
            torch.distributed.all_reduce(metrics[name].data)
            metrics[name].data /= args.world_size
            loss_checker += metrics[name]
        if loss_checker.isnan().any() or loss_checker.isinf().any():
            print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
            return lm_loss.detach(), 1, metrics

        # Calculate gradients, reduce across processes, and clip.
        timers('backward').start()
        lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
        timers('backward').stop()

        # Update parameters.
        skipped_iter, complete = 0, False
        timers('optimizer').start()
        if args.deepspeed:
            if model.is_gradient_accumulation_boundary():
                model.step()
                complete = True
                # Only advance the lr schedule when the optimizer really
                # stepped; fp16 overflow makes DeepSpeed skip the update.
                if not (args.fp16 and optimizer.overflow):
                    lr_scheduler.step()
                else:
                    skipped_iter = 1
            else:
                # Accumulation micro-step: gradients are stashed, no update.
                model.step()
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
        timers('optimizer').stop()
        if complete:
            break
    return lm_loss_reduced, skipped_iter, metrics
+
def backward_step(optimizer, model, loss, args, timers):
    """Run the backward pass; return the loss averaged over all workers."""
    # Only the DeepSpeed path is implemented.
    if not args.deepspeed:
        raise ValueError('Currently, we only support training with deepspeed.')

    # DeepSpeed owns backward (loss scaling + gradient all-reduce).
    model.backward(loss)

    # Average the scalar loss across data-parallel workers, for logging only.
    reduced_losses = loss.view(1)
    torch.distributed.all_reduce(reduced_losses.data)
    reduced_losses.data = reduced_losses.data / args.world_size

    # DeepSpeed backward already performed the gradient all-reduce; reset the
    # timer so the periodic timer log below does not break.
    timers('allreduce').reset()

    return reduced_losses
+
def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
    """Run ``args.eval_iters`` forward passes and return the mean LM loss."""
    forward_step = hooks['forward_step']

    # Evaluation mode disables dropout; training mode is restored at the end.
    model.eval()

    loss_sum = 0.0
    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step, args.eval_iters))
            # Forward evaluation.
            lm_loss, metrics = forward_step(data_iterator, model, args, timers)
            # When contiguous-memory optimizations are enabled, the buffers
            # are normally deallocated during backward; with no backward pass
            # here they must be reset manually after every forward.
            if args.deepspeed and args.deepspeed_activation_checkpointing:
                deepspeed.checkpointing.reset()
            loss_sum += lm_loss.data.detach().float().item()

    # Move model back to the train mode.
    model.train()

    return loss_sum / args.eval_iters
+
def evaluate_and_print_results(prefix, data_iterator, model,
                            args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
    """Evaluate, print the loss/perplexity summary, and return the loss."""
    lm_loss = evaluate(data_iterator, model, args, timers, verbose, hooks=hooks)
    # Clamp the exponent so the reported perplexity cannot overflow.
    lm_ppl = math.exp(min(20, lm_loss))
    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
    return lm_loss
+
def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, avg_metrics):
    """Log per-interval training stats to stdout (rank 0) and tensorboard."""
    fields = [
        ' iteration {:8d}/{:8d} |'.format(step, total_step),
        ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time),
        ' learning rate {:.3E} |'.format(lr),
        ' lm loss {:.6E} |'.format(loss),
    ]
    for key in avg_metrics:
        fields.append(' {} {:.6E} |'.format(key, avg_metrics[key]))
    if args.fp16:
        # DeepSpeed and the bare fp16 optimizer expose the scale differently.
        fields.append(' loss scale {:.1f} |'.format(
            optimizer.cur_scale if args.deepspeed else optimizer.loss_scale))
    print_rank_0(''.join(fields))
    if summary_writer is not None:
        summary_writer.add_scalar('Train/lr', lr, step)
        summary_writer.add_scalar('Train/train_loss', loss, step)
        summary_writer.add_scalar('Train/elapsed_time', elapsed_time, step)
+
+
def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
    """Print a framed validation summary and log it to tensorboard."""
    line = ' validation loss at {} | '.format(prefix)
    line += 'LM loss: {:.6E} | '.format(loss)
    line += 'LM PPL: {:.6E}'.format(ppl)
    rule = '-' * (len(line) + 1)
    print_rank_0('-' * 100)
    print_rank_0(rule)
    print_rank_0(line)
    print_rank_0(rule)
    if summary_writer is not None:
        summary_writer.add_scalar('Train/valid_ppl', ppl, step)
        summary_writer.add_scalar('Train/valid_loss', loss, step)
+        
+
+'''
+    Optional DeepSpeed Activation Checkpointing features
+    Gives access to partition activations, contiguous memory optimizations
+    and cpu checkpointing.
+
    Activation checkpoint requires keeping track of the random states
+    and setting the random seed for each MP process. Megatron uses
+    mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
+    for keeping track of the random states and setting the random seeds.
+    Since they are used in places outside of activation checkpointing,
+    we overwrite them to maintain consistency.
+
+    This must be done before all the calls to mpu.model_parallel_cuda_manual_seed
+    '''
+
def set_deepspeed_activation_checkpointing(args):
    """Route mpu's checkpointing/RNG entry points through DeepSpeed.

    Must run before any call to mpu.model_parallel_cuda_manual_seed so that
    every user of these functions — inside and outside activation
    checkpointing — shares the same RNG tracking.
    """
    deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    # Overwrite mpu's functions so the whole codebase consistently uses the
    # DeepSpeed-managed versions.
    mpu.checkpoint = deepspeed.checkpointing.checkpoint
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
+
+
def initialize_distributed(args):
    """Initialize torch.distributed and the model-parallel communicators."""
    # Pick the local CUDA device; an explicit local_rank wins over the
    # rank-derived default.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    # Build the TCP rendezvous address from launcher-provided env vars.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method='tcp://' + master_ip + ':' + master_port)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Optional DeepSpeed activation checkpointing features.
    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)
+
+
def set_random_seed(seed):
    """Seed python, numpy, torch, and the model-parallel RNG tracker.

    A None or non-positive seed leaves all generators untouched.
    """
    if seed is None or seed <= 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)
+
+
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        train_data, val_data, test_data = make_loaders(args)
        num_tokens = get_tokenizer().num_tokens

        # Pad the vocabulary so its size is divisible by
        # make_vocab_size_divisible_by * model-parallel world size.
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(
                         before, after - before, after))
        # Pack everything the other ranks need into one tensor for broadcast.
        token_counts = torch.cuda.LongTensor(
            [after, int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    args.do_train = token_counts[1].item()
    args.do_valid = token_counts[2].item()
    args.do_test = token_counts[3].item()

    return train_data, val_data, test_data, num_tokens
+
def see_memory_usage(message, force=False):
    """Print CUDA memory statistics on rank 0; no-op unless force=True."""
    if not force:
        return
    dist.barrier()
    if dist.get_rank() != 0:
        return
    gib = 1024 * 1024 * 1024
    print(message)
    print("Memory Allocated ", torch.cuda.memory_allocated() / gib, "GigaBytes")
    print("Max Memory Allocated ", torch.cuda.max_memory_allocated() / gib, "GigaBytes")
    print("Cache Allocated ", torch.cuda.memory_cached() / gib, "GigaBytes")
    print("Max cache Allocated ", torch.cuda.max_memory_cached() / gib, "GigaBytes")
    print(" ")
+
def seed_torch(seed=1029):
    """Seed every RNG (python/numpy/torch, all GPUs) and force deterministic
    cuDNN behaviour (slower, but reproducible)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade speed for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
+
diff --git a/training/learning_rates.py b/training/learning_rates.py
new file mode 100755
index 0000000..fd325c7
--- /dev/null
+++ b/training/learning_rates.py
@@ -0,0 +1,81 @@
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+import math
+
+
class AnnealingLR(_LRScheduler):
    """Anneals the learning rate from start to zero along a cosine curve.

    Schedule priority (highest first):
      1. auto-warmup: for the first ``auto_warmup_steps`` steps after
         ``last_iter``, lr = auto_warmup_rate * start_lr (smooths resuming).
      2. linear warmup over ``warmup_iter`` steps.
      3. decay by ``decay_style``: 'linear' to zero over ``num_iters``,
         'cosine' down to start_lr / decay_ratio, otherwise constant.
    """

    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']

    def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, auto_warmup_steps=50, auto_warmup_rate=0.05):
        assert warmup_iter <= num_iters
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.warmup_iter = warmup_iter
        self.init_step = last_iter
        self.num_iters = last_iter + 1
        self.end_iter = num_iters
        self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
        # Stored inverted: the cosine floor is start_lr / decay_ratio.
        self.decay_ratio = 1 / decay_ratio
        self.auto_warmup_steps = auto_warmup_steps
        self.auto_warmup_rate = auto_warmup_rate
        # Apply the initial lr immediately.
        self.step(self.num_iters)
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}')

    def get_lr(self):
        """Learning rate at the current step (self.num_iters)."""
        # Overrides everything else: small constant lr right after a resume.
        if self.num_iters <= self.init_step + self.auto_warmup_steps:
            return float(self.start_lr) * self.auto_warmup_rate

        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
            # Linear warmup.
            return float(self.start_lr) * self.num_iters / self.warmup_iter
        else:
            # Bug fix: `real_end_iter` was never defined, so the 'linear' and
            # 'cosine' styles raised NameError; the schedule span is end_iter.
            real_end_iter = self.end_iter
            if self.decay_style == self.DECAY_STYLES[0]:
                return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/real_end_iter)
            elif self.decay_style == self.DECAY_STYLES[1]:
                decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / real_end_iter)
                return self.start_lr / self.decay_ratio * (
                        (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
            elif self.decay_style == self.DECAY_STYLES[2]:
                # TODO: implement exponential decay
                return self.start_lr
            else:
                return self.start_lr

    def step(self, step_num=None):
        """Advance to `step_num` (default: one step) and update group lrs."""
        if step_num is None:
            step_num = self.num_iters + 1
        self.num_iters = step_num
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def state_dict(self):
        sd = {
                'start_lr': self.start_lr,
                'warmup_iter': self.warmup_iter,
                'num_iters': self.num_iters,
                'decay_style': self.decay_style,
                'end_iter': self.end_iter,
                'decay_ratio': self.decay_ratio
        }
        return sd

    def load_state_dict(self, sd):
        # Deliberately disabled: the schedule is rebuilt from args on load.
        pass
diff --git a/training/model_io.py b/training/model_io.py
new file mode 100644
index 0000000..fbc2254
--- /dev/null
+++ b/training/model_io.py
@@ -0,0 +1,162 @@
+
+# -*- encoding: utf-8 -*-
+'''
+@File    :   model_io.py
+@Time    :   2021/10/05 18:39:55
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import numpy as np
+
+import mpu
+from utils import print_rank_0
+
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    """Full path of the model-states file for a checkpoint.

    Layout: <checkpoints_path>/<tag>/mp_rank_XX_model_states.pt, where the
    tag is 'release' or the iteration number, optionally suffixed with the
    data-parallel rank for ZeRO-partitioned states.
    """
    d = 'release' if release else '{:d}'.format(iteration)
    if zero:
        d += '_zero_dp_rank_{}'.format(mpu.get_data_parallel_rank())
    return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
+
def get_checkpoint_tracker_filename(checkpoints_path):
    """Path of the 'latest' tracker file under *checkpoints_path*."""
    return os.path.join(checkpoints_path, 'latest')
+
def save_checkpoint(iteration, model, optimizer,
                    lr_scheduler, args):
    """Save a model checkpoint.

    Delegates the actual writing to the DeepSpeed path, then updates the
    'latest' tracker file on global rank 0. `optimizer` is accepted for
    interface compatibility; optimizer states are not saved here.
    """
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args)
    else:
        raise ValueError("training without deepspeed is not supported.")
    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
+
+
def save_ds_checkpoint(iteration, model, lr_scheduler, args):
    """Save a model checkpoint through DeepSpeed (no optimizer states)."""
    sd = {'iteration': iteration}
    if lr_scheduler is not None:
        sd['client_lr_scheduler'] = lr_scheduler.state_dict()
    if not args.no_save_rng:
        # Persist every RNG stream so pretraining can resume bit-exactly.
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
    save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
+    
def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
    """Write a DeepSpeed checkpoint WITHOUT optimizer/ZeRO states.

    Re-implements DeepSpeedEngine.save_checkpoint minus the optimizer-state
    saving. NOTE(review): relies on private engine methods
    (_checkpoint_tag_validation / _create_checkpoint_file / _save_checkpoint)
    — verify the call sequence still matches on DeepSpeed upgrades.

    Returns:
        True (mirrors the DeepSpeed API).
    """
    
    os.makedirs(save_dir, exist_ok=True)
    # Ensure tag is a string
    tag = str(tag)
    # Ensure checkpoint tag is consistent across ranks
    model._checkpoint_tag_validation(tag)
    # Real save via deepspeed
    model._create_checkpoint_file(save_dir, tag, False)
    model._save_checkpoint(save_dir, tag, client_state=client_state)
    # Save latest checkpoint tag
    if save_latest:
        with open(os.path.join(save_dir, 'latest'), 'w') as fd:
            fd.write(tag)

    return True
+
+
def get_checkpoint_iteration(args):
    """Parse the tracker file under args.load and decide what to resume.

    Returns:
        (iteration, release, success): success is False when no tracker file
        exists (train from scratch); iteration is 0 when the tag is 'release'.
    """
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    if not os.path.isfile(tracker_filename):
        # No tracker file at all: start from random initialization.
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from random')
        return 0, False, False

    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()

    iteration, release = 0, False
    try:
        iteration = int(metastring)
    except ValueError:
        release = (metastring == 'release')
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    return iteration, release, True
+
def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True):
    """Load a model checkpoint.

    Only module weights are restored here (plus, for pretraining, the RNG
    states); optimizer and lr-scheduler states are not loaded despite the
    parameters, which are kept for interface compatibility.

    Returns:
        The iteration to resume from (0 for finetuning or when no
        checkpoint was found).
    """

    iteration, release, success = get_checkpoint_iteration(args)
    if not success:
        return 0
    
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))
    # Load on CPU first to avoid spiking GPU memory.
    sd = torch.load(checkpoint_name, map_location='cpu')
    
    assert not args.do_train or args.deepspeed
    if args.deepspeed:
        module = model.module
    else: # inference without deepspeed
        module = model
        
    # only load module, other hyperparameters are just for recording.
    missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
    if len(unexpected_keys) > 0:
        print_rank_0(f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
    if len(missing_keys) > 0:
        if not args.do_train:
            raise ValueError(f'Missing keys for inference: {missing_keys}.')
        else: # new params
            # Missing keys are tolerated only when they belong to mixins,
            # which are freshly (re)initialized right below.
            assert all(name.find('mixins')>0 for name in missing_keys)
            module.reinit() # initialize mixins
            model.optimizer.refresh_fp32_params() # restore fp32 weights from module

    # Iterations.
    if args.mode == 'finetune':
        iteration = 0
    elif args.mode == 'pretrain' and not args.no_load_rng: # rng states.
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))
    # Free the (potentially large) state dict eagerly.
    del sd
    return iteration
-- 
GitLab