diff --git a/draw_diff.py b/draw_diff.py
index 6c2d74b5cf57c469752a5fb3c5170574115e6c94..640b06f4ea0adc6a487684590554f1002cc36629 100644
--- a/draw_diff.py
+++ b/draw_diff.py
@@ -1,12 +1,21 @@
 import numpy as np
 import torch
+def entropy(x):
+    a = np.array(x)
+    a = a / a.sum()  # normalize to a probability distribution
+    return -np.sum(a * np.log(a))  # note: entries equal to 0 produce nan via log(0)
+print(entropy([0.9999, 0.001]))
 def loadbao(name):
-    ret = []
+    ret1, ret2 = [], []  # per-line entropies and per-line sums
     with open(name, 'r') as fin:
         for line in fin:
-            a, b = line.split()
-            ret.append(abs(float(b)))
-    return ret
+            a = line.split()
+            aa = [float(x) for x in a[1:5]]  # fields 1-4 of the line as floats
+            b = entropy(aa)
+            c = sum(aa)
+            ret1.append(b)
+            ret2.append(c)
+    return np.array(ret1), np.array(ret2)
 
 def loadlion(name):
     ret1, ret2 = [], []
@@ -33,9 +42,9 @@ transform = transforms.Compose([
 img = torchvision.io.read_image('cat2.jpeg')
 img = transform(img) / 255.
 # a,b = np.array(loadlion('bed6.txt'))
-b = np.array(loadbao('bed6.txt'))
-for t in (b<0.999).nonzero()[0]:
+b, c = loadbao('bed1.txt')  # loadbao already returns numpy arrays
+for t in (b > 1.35).nonzero()[0]:  # mark grid cells whose entropy exceeds the threshold
     x,y = t // 64, t % 64
     sq(img, x*8, y*8, 7, 7)
 print(b.sum())
-torchvision.utils.save_image(img, 'example_cat6.jpg')
+torchvision.utils.save_image(img, 'example_cat.jpg')
diff --git a/finetune/__init__.py b/finetune/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/fp16/__init__.py b/fp16/__init__.py
deleted file mode 100755
index a2c68a1fa5f49537cdae103be38d97064262ac40..0000000000000000000000000000000000000000
--- a/fp16/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from .fp16util import (
-    BN_convert_float,
-    network_to_half,
-    prep_param_lists,
-    model_grads_to_master_grads,
-    master_params_to_model_params,
-    tofp16,
-    to_python_float,
-    clip_grad_norm,
-    convert_module,
-    convert_network,
-    FP16Model,
-)
-
-from .fp16 import *
-from .loss_scaler import *
diff --git a/fp16/fp16.py b/fp16/fp16.py
deleted file mode 100755
index c1c6af578417f9105713d2081df611dca69ff875..0000000000000000000000000000000000000000
--- a/fp16/fp16.py
+++ /dev/null
@@ -1,629 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Stable version of apex FP16 Optimizer"""
-import torch
-from torch import nn
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
-from .loss_scaler import DynamicLossScaler, LossScaler
-from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
-
-FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
-HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
-
-def conversion_helper(val, conversion):
-    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
-    if not isinstance(val, (tuple, list)):
-        return conversion(val)
-    rtn =  [conversion_helper(v, conversion) for v in val]
-    if isinstance(val, tuple):
-        rtn = tuple(rtn)
-    return rtn
-
-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
-    def half_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, FLOAT_TYPES):
-            val = val.half()
-        return val
-    return conversion_helper(val, half_conversion)
-
-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
-    def float_conversion(val):
-        val_typecheck = val
-        if isinstance(val_typecheck, (Parameter, Variable)):
-            val_typecheck = val.data
-        if isinstance(val_typecheck, HALF_TYPES):
-            val = val.float()
-        return val
-    return conversion_helper(val, float_conversion)
-
-class FP16_Module(nn.Module):
-    def __init__(self, module):
-        super(FP16_Module, self).__init__()
-        self.add_module('module', module.half())
-
-    def forward(self, *inputs, **kwargs):
-        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        return self.module.state_dict(destination, prefix, keep_vars)
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-# TODO:  Update overflow check + downscale to use Carl's fused kernel.
-class FP16_Optimizer(object):
-    """
-    :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, 
-    and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
-    For standard use, only two lines must be changed:  creating the :class:`FP16_Optimizer` instance,
-    and changing the call to ``backward``.
-
-    Example::
-
-        model = torch.nn.Linear(D_in, D_out).cuda().half()
-        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-        # Name the FP16_Optimizer instance to replace the existing optimizer
-        # (recommended but not required):
-        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
-        ...
-        # loss.backward() becomes:
-        optimizer.backward(loss)
-        ...
-
-    Example with dynamic loss scaling::
-
-        ...
-        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
-                                   # optional arg to control dynamic loss scaling behavior
-                                   # dynamic_loss_args={'scale_window' : 500})
-                                   # Usually, dynamic_loss_args is not necessary. 
-
-    Args:
-        init_optimizer (torch.optim.optimizer):  Existing optimizer created with the parameters to optimize.  Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones.  :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.  
-        static_loss_scale (float, optional, default=1.0):  Loss scale used internally to scale gradients computed by the model.  Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
-        dynamic_loss_scale (bool, optional, default=False):  Use dynamic loss scaling.  If True, this will override any ``static_loss_scale`` option.
-        dynamic_loss_args (dict, optional, default=None):  Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor.  Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor.  If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
-        verbose (bool, optional, default=True):  By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check.  If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``.  ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
-
-    ``init_optimizer`` is expected to have been constructed in the ordinary way.  
-    It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be 
-    named to replace ``init_optimizer``, for two reasons:  
-    First, it means that references to the same name
-    later in the file will not have to change.  
-    Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to 
-    modify ``init_optimizer``.  If you do choose a unique name for the new
-    :class:`FP16_Optimizer` instance, you should only work with this new instance,
-    because the preexisting optimizer might no longer behave as expected.
-
-    ``init_optimizer`` may be any Pytorch optimizer. 
-    It may contain a mixture of fp16 and fp32 parameters organized into any number of 
-    ``param_groups`` with different hyperparameters.  The :class:`FP16_Optimizer` constructor will 
-    ingest these ``param_groups`` and remember them. 
-
-    Calls to ::
-
-        loss.backward() 
-
-    must be replaced with ::
-
-        optimizer.backward(loss)  
-
-    because :class:`FP16_Optimizer` requires ownership of the backward pass to implement 
-    loss scaling and copies to master gradients.
-
-    .. note::
-        Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
-        are downscaled before being applied.  This means that adjusting the loss scale, or using
-        dynamic loss scaling, should not require retuning the learning rate or any other 
-        hyperparameters.
-
-
-    **Advanced options**
-
-    **Closures**:  :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
-    See docstring for :attr:`step`.
-
-    **Gradient clipping**:  Use :attr:`clip_master_grads`.
-    
-    **Multiple losses**:  If your model accumulates gradients from multiple losses,
-    this can be made more efficient by supplying ``update_master_grads=False``
-    to :attr:`backward`.  See docstring for :attr:`backward`.
-
-    **Manually adjusting loss scale**:  The current loss scale can be retrieved or set via ::
-
-        print(optimizer.loss_scale)
-        optimizer.loss_scale = new_loss_scale
-
-    For static loss scaling, manually adjusting the loss scale over time is a reasonable
-    thing to do.  During later epochs, gradients may become smaller, and a 
-    higher loss scale may be required, analogous to scheduling the learning rate.  Dynamic loss
-    scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting 
-    the loss scale is not recommended.
-
-    **Multi_GPU training**:  If the wrapped ``init_optimizer`` was created from a model wrapped in
-    Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` 
-    should still work as intended.
-    """
-
-    def __init__(self, 
-                 init_optimizer, 
-                 static_loss_scale=1.0, 
-                 dynamic_loss_scale=False,
-                 dynamic_loss_args=None,
-                 verbose=False):
-        if not torch.cuda.is_available:
-            raise SystemError("Cannot use fp16 without CUDA.")
-
-        self.verbose = verbose
-
-        self.optimizer = init_optimizer
-        # init_state_dict sets up an alternative way to cast per-param state tensors.
-        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
-        # init_state_dict = init_optimizer.state_dict()
-
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
-        self.fp32_from_fp32_groups = []
-        for i, param_group in enumerate(self.optimizer.param_groups):
-            self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
-            fp16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
-            for i, param in enumerate(param_group['params']):
-                if param.requires_grad:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
-                                         .format(param.size()))
-                        fp16_params_this_group.append(param)
-                        master_param = param.detach().clone().float()
-                        master_param.requires_grad = True
-                        # Copythe model parallel flag.
-                        master_param.model_parallel = param.model_parallel
-                        param_group['params'][i] = master_param
-                        fp32_from_fp16_params_this_group.append(master_param)
-                        # Reset existing state dict key to the new master param.
-                        # We still need to recast per-param state tensors, if any, to FP32.
-                        if param in self.optimizer.state:
-                           self.optimizer.state[master_param] = self.optimizer.state.pop(param) 
-                    elif param.type() == 'torch.cuda.FloatTensor':
-                        self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
-                                         .format(param.size()))
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
-                    else:
-                        raise TypeError("Wrapped parameters must be either "
-                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "  
-                                        "Received {}".format(param.type()))
-            
-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
-            self.fp32_from_fp32_groups.append(fp32_params_this_group)
-
-        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
-        # alternative way to cast per-param state tensors:
-        # self.optimizer.load_state_dict(init_state_dict)
-
-        if dynamic_loss_scale:
-            self.dynamic_loss_scale = True
-            if dynamic_loss_args is not None:
-                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
-            else:
-                self.loss_scaler = DynamicLossScaler()
-        else:
-            self.dynamic_loss_scale = False
-            self.loss_scaler = LossScaler(static_loss_scale)
-
-        self.overflow = False
-        self.first_closure_call_this_step = True
-
-        self.clip_grad_norm = clip_grad_norm
-
-    def maybe_print(self, msg):
-        if self.verbose:
-            print(msg)
-            
-    def __getstate__(self):
-        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
-
-    def __setstate__(self, state):
-        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
-
-    def zero_grad(self, set_grads_to_None=False):
-        """
-        Zero fp32 and fp16 parameter grads.
-        """
-        # In principle, only the .grad attributes of the model params need to be zeroed,
-        # because gradients are copied into the FP32 master params.  However, we zero
-        # all gradients owned by the optimizer, just to be safe:
-        for group in self.optimizer.param_groups:
-             for p in group['params']:
-                 if set_grads_to_None:
-                     p.grad = None
-                 else:
-                     if p.grad is not None:
-                         p.grad.detach_()
-                         p.grad.zero_()
-
-        # Zero fp16 gradients owned by the model:
-        for fp16_group in self.fp16_groups:
-            for param in fp16_group:
-                if set_grads_to_None:
-                    param.grad = None
-                else:
-                    if param.grad is not None:
-                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
-                        param.grad.zero_()
-
-    def _check_overflow(self):
-        params = [] 
-        for group in self.fp16_groups:
-            for param in group:
-                params.append(param)
-        for group in self.fp32_from_fp32_groups:
-            for param in group:
-                params.append(param)
-        self.overflow = self.loss_scaler.has_overflow(params)
-
-    def _update_scale(self, has_overflow=False):
-        self.loss_scaler.update_scale(has_overflow)
-
-    def _master_params_to_model_params(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
-
-    def _model_params_to_master_params(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            master_params_to_model_params(fp32_from_fp16_group, fp16_group)
-
-    # To consider:  Integrate distributed with this wrapper by registering a hook on each variable 
-    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
-    def _model_grads_to_master_grads(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
-
-    def _downscale_master(self):
-        if self.loss_scale != 1.0:
-            for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1./self.loss_scale)
-
-    def clip_master_grads(self, max_norm, norm_type=2):
-        """
-        Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
-
-        Args:
-            max_norm (float or int): max norm of the gradients
-            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-                infinity norm.
-
-        Returns:
-            Total norm of the current fp32 gradients (viewed as a single vector).
-
-        .. warning::
-            Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
-        """
-        if not self.overflow:
-            fp32_params = []
-            for param_group in self.optimizer.param_groups:
-                for param in param_group['params']:
-                    fp32_params.append(param)
-            return self.clip_grad_norm(fp32_params, max_norm, norm_type)
-        else:
-            return -1
-
-    def state_dict(self):
-        """
-        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
-        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
-        of the contained Pytorch optimizer.
-        Example::
-
-            checkpoint = {}
-            checkpoint['model'] = model.state_dict()
-            checkpoint['optimizer'] = optimizer.state_dict()
-            torch.save(checkpoint, "saved.pth")
-        """
-        state_dict = {}
-        state_dict['loss_scaler'] = self.loss_scaler
-        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
-        state_dict['overflow'] = self.overflow
-        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
-        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
-        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        """
-        Loads a state_dict created by an earlier call to state_dict(). 
-        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, 
-        whose parameters in turn came from ``model``, it is expected that the user 
-        will call ``model.load_state_dict()`` before
-        ``fp16_optimizer_instance.load_state_dict()`` is called.
-
-        Example::
-
-            model = torch.nn.Linear(D_in, D_out).cuda().half()
-            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
-            ...
-            checkpoint = torch.load("saved.pth")
-            model.load_state_dict(checkpoint['model'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-        """
-        # I think it should actually be ok to reload the optimizer before the model.
-        self.loss_scaler = state_dict['loss_scaler']
-        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
-        self.overflow = state_dict['overflow']
-        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
-        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
-        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
-        # The optimizer's hyperparameters and internal buffers are also up to date.  
-        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
-        # out of date.  There are two options.  
-        # 1:  Refresh the master params from the model's fp16 params.  
-        # This requires less storage but incurs precision loss.
-        # 2:  Save and restore the fp32 master copies separately.
-        # We choose option 2.
-        # 
-        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device 
-        # of their associated parameters, because it's possible those buffers might not exist yet in 
-        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been 
-        # constructed in the same way as the one whose state_dict we are loading, the same master params
-        # are guaranteed to exist, so we can just copy_() from the saved master params.
-        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
-            for current, saved in zip(current_group, saved_group):
-                current.data.copy_(saved.data)
-
-    def step(self, closure=None): # could add clip option.
-        """
-        If no closure is supplied, :attr:`step` should be called after 
-        ``fp16_optimizer_obj.backward(loss)``.
-        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
-        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
-        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
-        another forward pass using their model.
-
-        If a closure is supplied, :attr:`step` may be called without a prior call to 
-        :attr:`backward(loss)`.
-        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
-        However, the user should take care that any ``loss.backward()`` call within the closure
-        has been replaced by ``fp16_optimizer_obj.backward(loss)``.
-
-        Args:
-           closure (optional):  Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor.  closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.
-
-        Example with closure::
-
-            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an 
-            # existing pytorch optimizer.
-            for input, target in dataset:
-                def closure():
-                    optimizer.zero_grad()
-                    output = model(input)
-                    loss = loss_fn(output, target)
-                    # loss.backward() becomes:
-                    optimizer.backward(loss)
-                    return loss
-                optimizer.step(closure)
-
-        .. warning::
-            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
-
-        .. _`ordinary Pytorch optimizer use`:
-            http://pytorch.org/docs/master/optim.html#optimizer-step-closure
-        """
-
-        scale = self.loss_scaler.loss_scale
-        self._update_scale(self.overflow)
-
-        if self.overflow:
-            self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
-                .format(scale, self.loss_scale))
-            return
-        
-        if closure is not None:
-            retval = self._step_with_closure(closure)
-        else:
-            retval = self.optimizer.step()
-
-        self._master_params_to_model_params()
-
-        return retval
-
-    def _step_with_closure(self, closure):
-        def wrapped_closure():
-            # helpful for debugging
-            # print("Calling wrapped_closure, first_closure_call_this_step = {}"
-            #       .format(self.first_closure_call_this_step))
-            if self.first_closure_call_this_step:
-                # We expect that the fp16 params are initially fresh on entering self.step(),
-                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
-                # is called within self.optimizer.step().
-                self.first_closure_call_this_step = False
-            else:
-                # If self.optimizer.step() internally calls wrapped_closure more than once,
-                # it may update the fp32 params after each call.  However, self.optimizer 
-                # doesn't know about the fp16 params at all.  If the fp32 params get updated,
-                # we can't rely on self.optimizer to refresh the fp16 params.  We need
-                # to handle that manually:
-                self._master_params_to_model_params()
-            # Our API expects the user to give us ownership of the backward() call by
-            # replacing all calls to loss.backward() with optimizer.backward(loss).
-            # This requirement holds whether or not the call to backward() is made within a closure.
-            # If the user is properly calling optimizer.backward(loss) within "closure," 
-            # calling closure() here will give the fp32 master params fresh gradients
-            # for the optimizer to play with, so all wrapped_closure needs to do is call 
-            # closure() and return the loss.
-            temp_loss = closure() 
-            while(self.overflow):
-                scale = self.loss_scaler.loss_scale
-                self._update_scale(self.overflow)
-                self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
-                      "reducing to {}".format(scale, self.loss_scale))
-                temp_loss = closure()
-            return temp_loss
-
-        retval = self.optimizer.step(wrapped_closure)
-
-        self.first_closure_call_this_step = True
-
-        return retval
-
-    def backward(self, loss, update_master_grads=True, retain_graph=False):
-        """ 
-        :attr:`backward` performs the following conceptual steps:
-
-        1. fp32_loss = loss.float() (see first Note below)
-        2. scaled_loss = fp32_loss*loss_scale
-        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
-        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
-        5. Finally, master grads are divided by loss_scale.
-
-        In this way, after :attr:`backward`, the master params have fresh gradients,
-        and :attr:`step` may be called.
-
-        .. note::
-            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
-            This provides some additional safety against overflow if the user has supplied an 
-            fp16 loss value.  
-            However, for maximum overflow safety, the user should
-            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to 
-            :attr:`backward`.
-
-        .. warning::
-            The gradients found in a model's leaves after the call to 
-            :attr:`backward` should not be regarded as valid in general, 
-            because it's possible 
-            they have been scaled (and in the case of dynamic loss scaling, 
-            the scale factor may change over time).  
-            If the user wants to inspect gradients after a call to :attr:`backward`,  
-            only the master gradients should be regarded as valid.  These can be retrieved via
-            :attr:`inspect_master_grad_data()`.
-
-        Args:
-            loss:  The loss output by the user's model.  loss may be either float or half (but see first Note above).
-            update_master_grads (bool, optional, default=True):  Option to copy fp16 grads to fp32 grads on this call.  By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration.  If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
-            retain_graph (bool, optional, default=False):  Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``.  If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
-
-        Example::
-
-            # Ordinary operation:
-            optimizer.backward(loss)
-
-            # Naive operation with multiple losses (technically valid, but less efficient):
-            # fp32 grads will be correct after the second call,  but 
-            # the first call incurs an unnecessary fp16->fp32 grad copy.
-            optimizer.backward(loss1)
-            optimizer.backward(loss2)
-
-            # More efficient way to handle multiple losses:
-            # The fp16->fp32 grad copy is delayed until fp16 grads from all 
-            # losses have been accumulated.
-            optimizer.backward(loss1, update_master_grads=False)
-            optimizer.backward(loss2, update_master_grads=False)
-            optimizer.update_master_grads()
-        """ 
-        # To consider:  try multiple backward passes using retain_grad=True to find 
-        # a loss scale that works.  After you find a loss scale that works, do a final dummy
-        # backward pass with retain_graph=False to tear down the graph.  Doing this would avoid 
-        # discarding the iteration,  but probably wouldn't improve overall efficiency.  
-        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
-        if update_master_grads:
-            self.update_master_grads()
-
-    def update_master_grads(self):
-        """
-        Copy the ``.grad`` attribute from stored references to fp16 parameters to 
-        the ``.grad`` attribute of the fp32 master parameters that are directly 
-        updated by the optimizer.  :attr:`update_master_grads` only needs to be called if
-        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
-        """
-        if self.dynamic_loss_scale:
-            self._check_overflow()
-            if self.overflow: return
-        self._model_grads_to_master_grads()
-        self._downscale_master()
-
-    def inspect_master_grad_data(self):
-        """
-        When running with :class:`FP16_Optimizer`, 
-        ``.grad`` attributes of a model's fp16 leaves should not be
-        regarded as truthful, because they might be scaled.  
-        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
-        the fp32 master params' ``.grad``
-        attributes will contain valid gradients properly divided by the loss scale.  However, 
-        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be 
-        nonintuitive.  :attr:`inspect_master_grad_data`
-        allows those gradients to be viewed with shapes corresponding to their associated model leaves.
-
-        Returns:
-            List of lists (one list for each parameter group).  The list for each parameter group
-            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.                 
-        """
-        if self.overflow:
-            print("Warning:  calling FP16_Optimizer.inspect_master_grad_data while in an overflow state.  "
-                  "Gradients are currently invalid (may be inf, nan, or stale).  Returning None.")
-            return None
-        else:
-            # The optimizer owns only references to master params.
-            master_grads_data = []
-            for param_group in self.optimizer.param_groups:
-                master_grads_this_group = []
-                for param in param_group['params']:
-                    if param.grad is not None:
-                        master_grads_this_group.append(param.grad.data)
-                    else:
-                        master_grads_this_group.append(None)
-                master_grads_data.append(master_grads_this_group)
-            return master_grads_data
-
-
-    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
-    def _get_loss_scale(self):
-        return self.loss_scaler.loss_scale
-
-    def _set_loss_scale(self, value):
-        self.loss_scaler.cur_scale = value
-
-    loss_scale = property(_get_loss_scale, _set_loss_scale)
-
-    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
-    def _get_state(self):
-        return self.optimizer.state
-
-    def _set_state(self, value):
-        self.optimizer.state = value
-
-    state = property(_get_state, _set_state)
-
-    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
-    # (for example, to adjust the learning rate)
-    def _get_param_groups(self):
-        return self.optimizer.param_groups
-
-    def _set_param_groups(self, value):
-        self.optimizer.param_groups = value
-
-    param_groups = property(_get_param_groups, _set_param_groups)
diff --git a/fp16/fp16util.py b/fp16/fp16util.py
deleted file mode 100755
index 8572bb1b82ce6661b62d235df18ccf5479355a42..0000000000000000000000000000000000000000
--- a/fp16/fp16util.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
-import mpu
-
-
-class tofp16(nn.Module):
-    """
-    Utility module that implements::
-
-        def forward(self, input):
-            return input.half()
-    """
-
-    def __init__(self):
-        super(tofp16, self).__init__()
-
-    def forward(self, input):
-        return input.half()
-
-
-def BN_convert_float(module):
-    """
-    Utility function for network_to_half().
-
-    Retained for legacy purposes.
-    """
-    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
-        module.float()
-    for child in module.children():
-        BN_convert_float(child)
-    return module
-
-
-def network_to_half(network):
-    """
-    Convert model to half precision in a batchnorm-safe way.
-
-    Retained for legacy purposes. It is recommended to use FP16Model.
-    """
-    return nn.Sequential(tofp16(), BN_convert_float(network.half()))
-
-
-def convert_module(module, dtype):
-    """
-    Converts a module's immediate parameters and buffers to dtype.
-    """
-    for param in module.parameters(recurse=False):
-        if param is not None:
-            if param.data.dtype.is_floating_point:
-                param.data = param.data.to(dtype=dtype)
-            if param._grad is not None and param._grad.data.dtype.is_floating_point:
-                param._grad.data = param._grad.data.to(dtype=dtype)
-
-    for buf in module.buffers(recurse=False):
-        if buf is not None and buf.data.dtype.is_floating_point:
-            buf.data = buf.data.to(dtype=dtype)
-
-
-def convert_network(network, dtype):
-    """
-    Converts a network's parameters and buffers to dtype.
-    """
-    for module in network.modules():
-        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
-            continue
-        convert_module(module, dtype)
-    return network
-
-
-class FP16Model(nn.Module):
-    """
-    Convert model to half precision in a batchnorm-safe way.
-    """
-
-    def __init__(self, network):
-        super(FP16Model, self).__init__()
-        self.network = convert_network(network, dtype=torch.half)
-
-    def forward(self, *inputs):
-        inputs = tuple(t.half() for t in inputs)
-        return self.network(*inputs)
-
-
-def backwards_debug_hook(grad):
-    raise RuntimeError("master_params recieved a gradient in the backward pass!")
-
-def prep_param_lists(model, flat_master=False):
-    """
-    Creates a list of FP32 master parameters for a given model, as in
-    `Training Neural Networks with Mixed Precision:  Real Examples`_.
-
-    Args:
-        model (torch.nn.Module): Existing Pytorch model
-        flat_master (bool, optional, default=False):  Flatten the master parameters into a single tensor, as a performance optimization.
-    Returns:
-        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`.  ``master_params`` is a list of FP32 master gradients.  If ``flat_master=True``, ``master_params`` will be a list with one element.
-
-    Example::
-
-        model_params, master_params = prep_param_lists(model)
-
-    .. warning::
-        Currently, if ``flat_master=True``, all the model's parameters must be the same type.  If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
-
-    .. _`Training Neural Networks with Mixed Precision:  Real Examples`:
-        http://on-demand.gputechconf.com/gtc/2018/video/S81012/
-    """
-    model_params = [param for param in model.parameters() if param.requires_grad]
-
-    if flat_master:
-        # Give the user some more useful error messages
-        try:
-            # flatten_dense_tensors returns a contiguous flat array.
-            # http://pytorch.org/docs/master/_modules/torch/_utils.html
-            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
-        except:
-            print("Error in prep_param_lists:  model may contain a mixture of parameters "
-                      "of different types.  Use flat_master=False, or use F16_Optimizer.")
-            raise
-        master_params = torch.nn.Parameter(master_params)
-        master_params.requires_grad = True
-        # master_params.register_hook(backwards_debug_hook)
-        if master_params.grad is None:
-            master_params.grad = master_params.new(*master_params.size())
-        return model_params, [master_params]
-    else:
-        master_params = [param.clone().float().detach() for param in model_params]
-        for param in master_params:
-            param.requires_grad = True
-        return model_params, master_params
-
-
-def model_grads_to_master_grads(model_params, master_params, flat_master=False):
-    """
-    Copy model gradients to master gradients.  
-
-    Args:
-        model_params:  List of model parameters created by :func:`prep_param_lists`.
-        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
-    """
-    if flat_master:
-        # The flattening may incur one more deep copy than is necessary.
-        master_params[0].grad.data.copy_(
-            _flatten_dense_tensors([p.grad.data for p in model_params]))
-    else:
-        for model, master in zip(model_params, master_params):
-            if model.grad is not None:
-                if master.grad is None:
-                    master.grad = Variable(master.data.new(*master.data.size()))
-                master.grad.data.copy_(model.grad.data)
-            else:
-                master.grad = None
-
-
-def master_params_to_model_params(model_params, master_params, flat_master=False):
-    """
-    Copy master parameters to model parameters.
-
-    Args:
-        model_params:  List of model parameters created by :func:`prep_param_lists`.
-        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
-    """
-    if flat_master:
-        for model, master in zip(model_params, 
-                                 _unflatten_dense_tensors(master_params[0].data, model_params)):
-            model.data.copy_(master)
-    else:
-        for model, master in zip(model_params, master_params):
-            model.data.copy_(master.data)
-
-# Backward compatibility fixes
-
-def to_python_float(t):
-    if hasattr(t, 'item'):
-        return t.item()
-    else:
-        return t[0]
-
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
-
-clip_grad_norm = mpu.clip_grad_norm
-#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
-#    clip_grad_norm = torch.nn.utils.clip_grad_norm
-#else:
-#    clip_grad_norm = torch.nn.utils.clip_grad_norm_
diff --git a/fp16/loss_scaler.py b/fp16/loss_scaler.py
deleted file mode 100755
index 4426bf4a9e99dfd49cd53b67e1201a7a27052bbe..0000000000000000000000000000000000000000
--- a/fp16/loss_scaler.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import mpu
-
-# item() is a recent addition, so this helps with backward compatibility.
-def to_python_float(t):
-    if hasattr(t, 'item'):
-        return t.item()
-    else:
-        return t[0]
-
-class LossScaler:
-    """
-    Class that manages a static loss scale.  This class is intended to interact with
-    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
-
-    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to 
-    :class:`FP16_Optimizer`'s constructor.
-
-    Args:
-        scale (float, optional, default=1.0):  The loss scale.
-    """
-
-    def __init__(self, scale=1):
-        self.cur_scale = scale
-
-    # `params` is a list / generator of torch.Variable
-    def has_overflow(self, params):
-        return False
-
-    # `x` is a torch.Tensor
-    def _has_inf_or_nan(x):
-        return False
-
-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss*self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
-class DynamicLossScaler:
-    """
-    Class that manages dynamic loss scaling.  It is recommended to use :class:`DynamicLossScaler`
-    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of 
-    :class:`FP16_Optimizer`.  However, it's important to understand how :class:`DynamicLossScaler`
-    operates, because the default options can be changed using the
-    the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
-
-    Loss scaling is designed to combat the problem of underflowing gradients encountered at long
-    times when training fp16 networks.  Dynamic loss scaling begins by attempting a very high loss
-    scale.  Ironically, this may result in OVERflowing gradients.  If overflowing gradients are
-    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has 
-    occurred.
-    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
-    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.  
-    If a certain number of iterations occur without overflowing gradients detected,
-    :class:`DynamicLossScaler` increases the loss scale once more.
-    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of 
-    always using the highest loss scale possible without incurring overflow.
-
-    Args:
-        init_scale (float, optional, default=2**32):  Initial loss scale attempted by :class:`DynamicLossScaler.`
-        scale_factor (float, optional, default=2.0):  Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``.  If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 
-        scale_window (int, optional, default=1000):  Number of consecutive iterations without an overflow to wait before increasing the loss scale.
-    """
-
-    def __init__(self,
-                 init_scale=2**32,
-                 scale_factor=2.,
-                 scale_window=1000,
-                 min_scale=1,
-                 delayed_shift=1,
-                 consecutive_hysteresis=False):
-        self.cur_scale = init_scale
-        self.cur_iter = 0
-        self.last_overflow_iter = -1
-        self.scale_factor = scale_factor
-        self.scale_window = scale_window
-        self.min_scale = min_scale
-        self.delayed_shift = delayed_shift
-        self.cur_hysteresis = delayed_shift
-        self.consecutive_hysteresis = consecutive_hysteresis
-
-    # `params` is a list / generator of torch.Variable
-    def has_overflow_serial(self, params):
-        for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
-                return True
-
-        return False
-
-    def has_overflow(self, params):
-        overflow = self.has_overflow_serial(params)
-        # Since each model parallel GPU carries only part of the model,
-        # make sure overflow flag is synced across all the model parallel GPUs
-        overflow_gpu = torch.cuda.ByteTensor([overflow])
-        torch.distributed.all_reduce(overflow_gpu,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
-        overflow = overflow_gpu[0].item()
-        return bool(overflow)
-
-
-    # `x` is a torch.Tensor
-    def _has_inf_or_nan(x):
-        try:
-            # if x is half, the .float() incurs an additional deep copy, but it's necessary if 
-            # Pytorch's .sum() creates a one-element tensor of the same type as x 
-            # (which is true for some recent version of pytorch).
-            cpu_sum = float(x.float().sum())
-            # More efficient version that can be used if .sum() returns a Python scalar
-            # cpu_sum = float(x.sum())
-        except RuntimeError as instance:
-            # We want to check if inst is actually an overflow exception.
-            # RuntimeError could come from a different error.
-            # If so, we still want the exception to propagate.
-            if "value cannot be converted" not in instance.args[0]:
-                raise
-            return True
-        else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
-                return True
-            return False
-
-    # `overflow` is boolean indicating whether the gradient overflowed
-    def update_scale(self, overflow):
-
-        if not hasattr(self, 'min_scale'):
-            self.min_scale = 1
-        if not hasattr(self, 'delayed_shift'):
-            self.delayed_shift = 1
-        if not hasattr(self, 'cur_hysteresis'):
-            self.cur_hysteresis = 1
-        if not hasattr(self, 'consecutive_hysteresis'):
-            self.consecutive_hysteresis = True
-        if overflow:
-            # self.cur_scale /= self.scale_factor
-            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
-                self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
-            else:
-                self.cur_hysteresis -= 1
-            self.last_overflow_iter = self.cur_iter
-        else:
-            if self.consecutive_hysteresis:
-                self.cur_hysteresis = self.delayed_shift
-            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
-                if not self.consecutive_hysteresis:
-                    self.cur_hysteresis = self.delayed_shift
-                self.cur_scale *= self.scale_factor
-        self.cur_iter += 1
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss*self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-        
-##############################################################        
-# Example usage below here -- assuming it's in a separate file
-##############################################################
-"""
-TO-DO separate out into an example.
-if __name__ == "__main__":
-    import torch
-    from torch.autograd import Variable
-    from dynamic_loss_scaler import DynamicLossScaler
-
-    # N is batch size; D_in is input dimension;
-    # H is hidden dimension; D_out is output dimension.
-    N, D_in, H, D_out = 64, 1000, 100, 10
-
-    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
-    x = Variable(torch.randn(N, D_in), requires_grad=False)
-    y = Variable(torch.randn(N, D_out), requires_grad=False)
-
-    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
-    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
-    parameters = [w1, w2]
-
-    learning_rate = 1e-6
-    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
-    loss_scaler = DynamicLossScaler()
-
-    for t in range(500):
-        y_pred = x.mm(w1).clamp(min=0).mm(w2)
-        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
-        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
-        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
-        print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
-
-        # Run backprop
-        optimizer.zero_grad()
-        loss.backward()
-        
-        # Check for overflow
-        has_overflow = DynamicLossScaler.has_overflow(parameters)
-        
-        # If no overflow, unscale grad and update as usual
-        if not has_overflow:
-            for param in parameters:
-                param.grad.data.mul_(1. / loss_scaler.loss_scale)
-            optimizer.step()
-        # Otherwise, don't do anything -- ie, skip iteration
-        else:
-            print('OVERFLOW!')
-
-        # Update loss scale for next iteration
-        loss_scaler.update_scale(has_overflow)
-
-"""
diff --git a/generate_samples.py b/generate_samples.py
index f0c5ba54e71a87d8443e08c342714345f1bb000d..bbca2643061e369aa5cc26730f6c2e58b8274f7c 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -33,9 +33,7 @@ from data_utils import get_tokenizer
 import mpu
 import deepspeed
 
-from fp16 import FP16_Module
 from model import GPT2Model
-from model import DistributedDataParallel as DDP
 from utils import print_rank_0
 from pretrain_gpt2 import get_model
 import math
diff --git a/model/__init__.py b/model/__init__.py
index 4f3057c51f7f4b32eeff224413c2cf88da45e008..beb12711b72f1688caf8304000573b1289081fa8 100755
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .distributed import *
 from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization
 from .gpt2_modeling import GPT2Model
 
diff --git a/model/base_model.py b/model/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c416e848be01fd7aef23f4dc09d2463d66d101
--- /dev/null
+++ b/model/base_model.py
@@ -0,0 +1,59 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   base_model.py
+@Time    :   2021/10/01 22:40:33
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# standard imports
+import os
+import sys
+import math
+import random
+import torch
+from functools import partial
+
+from mpu import BaseTransformer
+
+class BaseModel(torch.nn.Module):
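+    # Thin wrapper around BaseTransformer: subclasses may define hook methods
+    # (collected in collect_hooks) to customize parts of the forward pass, and
+    # append mixins to self.mixins so reinit() can re-initialize them against it.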
+    def __init__(self, args, transformer=None):
+        super(BaseModel, self).__init__()
+        self.hooks = self.collect_hooks()
+        if transformer is not None:
+            self.transformer = transformer
+        else:
+            self.transformer = BaseTransformer(
+                num_layers=args.num_layers,
+                vocab_size=args.vocab_size,
+                hidden_size=args.hidden_size,
+                num_attention_heads=args.num_attention_heads,
+                max_sequence_length=args.max_position_embeddings,
+                embedding_dropout_prob=args.hidden_dropout,
+                attention_dropout_prob=args.attention_dropout,
+                output_dropout_prob=args.hidden_dropout,
+                checkpoint_activations=args.checkpoint_activations,
+                checkpoint_num_layers=args.checkpoint_num_layers,
+                sandwich_ln=args.sandwich_ln,
+                parallel_output=True,
+                hooks=self.hooks
+            )
+        self.mixins = torch.nn.ModuleList()
+        
+    def reinit(self):
+        for m in self.mixins:
+            m.reinit(self.transformer)
+    
+    def collect_hooks(self):
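+        # Collect the hook methods listed below that the (sub)class actually
+        # implements; they are passed to BaseTransformer via `hooks` in __init__.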
+        names = ['word_embedding_forward', 'position_embedding_forward',
+                    'attention_forward', 'mlp_forward', 'final_forward']
+        hooks = {}
+        for name in names:
+            if hasattr(self, name):
+                hooks[name] = partial(getattr(self, name), self)
+        return hooks
\ No newline at end of file
diff --git a/model/distributed.py b/model/distributed.py
deleted file mode 100755
index c5c54997e0b289747d47b0396df4d93060445538..0000000000000000000000000000000000000000
--- a/model/distributed.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable
-from torch.nn.parallel.distributed import DistributedDataParallel as DDP
-
-import mpu
-
-
-class PyTorchDistributedDataParallel(DDP):
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        return sd
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-
-class DistributedDataParallel(Module):
-
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
-
-        self.module = module
-        self.data_parallel_group = mpu.get_data_parallel_group()
-        src_rank = mpu.get_model_parallel_rank()
-        for p in self.module.parameters():
-            if torch.is_tensor(p):
-                dist.broadcast(p, src_rank, group=self.data_parallel_group)
-
-        def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
-                              " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-        #    handle = param.register_hook(allreduce_hook)
-            #self.hooks.append(allreduce_hook)
-            #self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params
-
-    def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
-        return self.module(*inputs, **kwargs)
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-       # for handle, hook in zip(self.hook_handles, self.hooks):
-       #     d = handle.hooks_dict_ref()
-       #     d[handle.id] = hook
-
-        return sd
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''
-
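For reference, the gradient path of the deleted wrapper is: bucket gradients by dtype, flatten each bucket, all-reduce it once over the data-parallel group, and copy the averaged values back. A rough sketch of that pattern using the same private PyTorch helpers, assuming the process group has been initialized elsewhere:

```python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients(module, group=None):
    """Average gradients across the data-parallel group, one flat
    all-reduce per gradient dtype (sketch of the removed allreduce_params)."""
    buckets = {}
    for param in module.parameters():
        if param.requires_grad and param.grad is not None:
            buckets.setdefault(param.data.type(), []).append(param)
    for params in buckets.values():
        grads = [p.grad.data for p in params]
        coalesced = _flatten_dense_tensors(grads)
        coalesced /= dist.get_world_size(group=group)
        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
```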
diff --git a/model/gpt2.py b/model/gpt2.py
new file mode 100755
index 0000000000000000000000000000000000000000..514f8223ae6cfed27d4aa86c5151fa26a16fa228
--- /dev/null
+++ b/model/gpt2.py
@@ -0,0 +1,17 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   gpt2.py
+@Time    :   2021/10/02 00:37:22
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from .base_model import BaseModel
+
diff --git a/model/gpt2_modeling.py b/model/gpt2_modeling.py
deleted file mode 100755
index c854f1bce00f0a0397ebb796c1f7f17f7e3c1b4b..0000000000000000000000000000000000000000
--- a/model/gpt2_modeling.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""GPT-2 model."""
-
-import torch
-import torch.nn.functional as F
-import argparse
-
-import mpu
-
-
-def init_method_normal(std=0.02):
-    """Init method based on normal distribution.
-
-    This is only used for embeddings. The transformer has its
-    own initializer.
-    """
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-    return init_
-
-
-def gpt2_get_params_for_weight_decay_optimization(module):
-
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None and p.requires_grad])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias' and p.requires_grad])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias' and p.requires_grad])
-    return weight_decay_params, no_weight_decay_params
-
-
-class GPT2Model(torch.nn.Module):
-    """GPT-2 Language model.
-
-    The output of the forward method are the logits (parallel or
-    serial depending on the `parallel_output` flag.
-    """
-
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 max_memory_length,
-                 checkpoint_activations,
-                 sandwich_ln,
-                 checkpoint_num_layers=1,
-                 parallel_output=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard'),
-                 finetune=False
-                 ):
-
-        super(GPT2Model, self).__init__()
-
-        self.parallel_output = parallel_output
-
-        init_method = init_method_normal(std=0.02)
-
-        # Word embeddings (parallel).
-        self.word_embeddings = mpu.VocabParallelEmbedding(
-            vocab_size, hidden_size, init_method=init_method)
-
-        # Transformer
-        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
-                                                       hidden_size,
-                                                       num_attention_heads,
-                                                       max_sequence_length,
-                                                       max_memory_length,
-                                                       embedding_dropout_prob,
-                                                       attention_dropout_prob,
-                                                       output_dropout_prob,
-                                                       checkpoint_activations,
-                                                       checkpoint_num_layers,
-                                                       sandwich_ln=sandwich_ln,
-                                                       sparse_config=sparse_config,
-                                                       finetune=finetune
-                                                       )
-
-    def forward(self, input_ids, position_ids, attention_mask, *mems):
-        # Embeddings.
-        words_embeddings = self.word_embeddings(input_ids)
-        embeddings = words_embeddings
-
-        # Transformer.
-        transformer_output = self.transformer(embeddings, position_ids, attention_mask, *mems)
-        logits, *hidden_layers = transformer_output
-        # Parallel logits.
-        logits_parallel = mpu.copy_to_model_parallel_region(
-            logits)
-        logits_parallel = F.linear(logits_parallel,
-                                   self.word_embeddings.weight)
-
-        if self.parallel_output:
-            return (logits_parallel, *hidden_layers)
-
-        return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
-    
-    def init_plus_from_old(self):
-        self.transformer.init_plus_from_old()
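`gpt2_get_params_for_weight_decay_optimization` above returns two parameter-group dicts, with `weight_decay: 0.0` set on the group holding LayerNorm weights and biases; the two dicts are meant to be handed directly to the optimizer. A sketch of that usage, where the optimizer class, learning rate, and `model` are placeholders rather than values from this repository:

```python
import torch

# `model` is any torch.nn.Module containing LayerNorm/bias parameters.
wd_group, no_wd_group = gpt2_get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(
    [wd_group, no_wd_group],
    lr=1e-4,
    weight_decay=0.01,  # applies to wd_group; no_wd_group overrides it with 0.0
)
```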
diff --git a/model/mixins.py b/model/mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e195e3c4e37b0ce4982974d17eff56129941154
--- /dev/null
+++ b/model/mixins.py
@@ -0,0 +1,69 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   mixins.py
+@Time    :   2021/10/01 17:52:40
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+from mpu import ColumnParallelLinear, RowParallelLinear
+from mpu.transformer import unscaled_init_method
+
+class BaseMixin(torch.nn.Module):
+    def __init__(self):
+        super(BaseMixin, self).__init__()
+        # define new params
+    def reinit(self, transformer, *pre_mixins):
+        # reload the initial params from previous trained modules
+        pass
+
+class PositionEmbeddingMixin(BaseMixin):
+    def __init__(self, new_sequence_length, hidden_size, 
+                init_method_std=0.02, reinit_slice=slice(-1024, None)
+        ):
+        super(PositionEmbeddingMixin, self).__init__()
+        self.reinit_slice = reinit_slice
+        self.position_embeddings = torch.nn.Embedding(new_sequence_length, hidden_size)
+        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
+    def reinit(self, transformer, *pre_mixins):
+        old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
+        old_len, hidden_size = old_weights.shape
+        assert hidden_size == self.position_embeddings.weight.shape[-1]
+        self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
+
+class AttentionMixin(BaseMixin):
+    def __init__(self, num_layers,
+                hidden_size, 
+                init_method=unscaled_init_method(0.02),
+                output_layer_init_method=unscaled_init_method(0.02)
+        ):
+        super(AttentionMixin, self).__init__()
+        self.num_layers = num_layers # replace attention in the LAST n layers
+        self.query_key_value = torch.nn.ModuleList(
+            [ColumnParallelLinear(hidden_size, 3*hidden_size, stride=3,
+                gather_output=False, init_method=init_method)
+                for layer_id in range(num_layers)
+            ])
+        self.dense = torch.nn.ModuleList(
+            [RowParallelLinear(hidden_size,
+                hidden_size,
+                input_is_parallel=True,
+                init_method=output_layer_init_method)
+                for layer_id in range(num_layers)
+            ])
+    def reinit(self, transformer, *pre_mixins):
+        start_layer = len(transformer.layers) - self.num_layers
+        assert start_layer >= 0
+        for layer_id in range(self.num_layers):
+            old_attention = transformer.layers[start_layer + layer_id].attention
+            self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
+            self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
+            self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
+            self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
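`AttentionMixin` clones the query/key/value and output projections of the last `num_layers` layers, and `reinit(transformer)` fills them with the pretrained weights. A hypothetical wiring through `BaseModel`, assuming `model` is a `BaseModel` subclass and `args.hidden_size` matches the pretrained transformer:

```python
# Sketch only: `model` and `args` are placeholders for a BaseModel
# instance and its argparse namespace.
from model.mixins import AttentionMixin

mixin = AttentionMixin(num_layers=4, hidden_size=args.hidden_size)
model.mixins.append(mixin)
# BaseModel.reinit() calls mixin.reinit(model.transformer), so afterwards
# mixin.query_key_value[i] / mixin.dense[i] hold copies of the weights of
# transformer.layers[-4 + i].attention.
model.reinit()
```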
diff --git a/mpu/__init__.py b/mpu/__init__.py
index 25f39b459882fbc7fa0ac2d544cdefe10e148f9f..358f99703eca297a4a8d06fbddbd9e2eb089f559 100755
--- a/mpu/__init__.py
+++ b/mpu/__init__.py
@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
 
 from .data import broadcast_data
 
-from .grads import clip_grad_norm
-
 from .initialize import destroy_model_parallel
 from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
@@ -47,5 +45,5 @@ from .random import partition_activations_in_checkpoint
 from .random import get_cuda_rng_tracker
 from .random import model_parallel_cuda_manual_seed
 
-from .sparse_transformer import GPT2ParallelTransformer
-from .sparse_transformer import LayerNorm
+from .transformer import BaseTransformer
+from .transformer import LayerNorm
diff --git a/mpu/grads.py b/mpu/grads.py
deleted file mode 100755
index 0ae9cf5d300f206e4e2d144bd6e4f0fb39f049c4..0000000000000000000000000000000000000000
--- a/mpu/grads.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Parts of the code here are adapted from PyTorch
-# repo: https://github.com/pytorch/pytorch
-
-
-import torch
-from torch._six import inf
-
-from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
-
-
-def clip_grad_norm(parameters, max_norm, norm_type=2):
-    """Clips gradient norm of an iterable of parameters.
-
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters. Note that
-    the gradients are modified in place.
-
-    Arguments:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float or int): max norm of the gradients
-        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
-
-    Returns:
-        Total norm of the parameters (viewed as a single vector).
-    """
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    max_norm = float(max_norm)
-    norm_type = float(norm_type)
-    if norm_type == inf:
-        total_norm = max(p.grad.data.abs().max() for p in parameters)
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        # Take max across all GPUs.
-        torch.distributed.all_reduce(total_norm_cuda,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=get_model_parallel_group())
-        total_norm = total_norm_cuda[0].item()
-    else:
-        total_norm = 0
-        for p in parameters:
-            if p.model_parallel or (get_model_parallel_rank() == 0):
-                param_norm = p.grad.data.norm(norm_type)
-                total_norm += param_norm.item() ** norm_type
-        # Sum across all model parallel GPUs.
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        torch.distributed.all_reduce(total_norm_cuda,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=get_model_parallel_group())
-        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
-    clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
-        for p in parameters:
-            p.grad.data.mul_(clip_coef)
-    return total_norm
diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py
deleted file mode 100755
index 37f3f9dbaeacff4ddcf1c500c75613bee878add6..0000000000000000000000000000000000000000
--- a/mpu/sparse_transformer.py
+++ /dev/null
@@ -1,729 +0,0 @@
-# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Transformer."""
-
-import math
-import random
-import argparse
-
-import torch
-import torch.nn.init as init
-import torch.nn.functional as F
-from apex.normalization.fused_layer_norm import FusedLayerNorm
-
-from .initialize import get_model_parallel_world_size
-from .layers import ColumnParallelLinear
-from .layers import RowParallelLinear
-from .mappings import gather_from_model_parallel_region
-
-import deepspeed
-
-from .random import checkpoint
-from .random import get_cuda_rng_tracker
-
-from .utils import divide, sqrt
-from .utils import split_tensor_along_last_dim
-import torch.distributed as dist
-
-class LayerNorm(FusedLayerNorm):
-    def __init__(self, *args, pb_relax=False, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.pb_relax = pb_relax
-    def forward(self, x):
-        if not self.pb_relax:
-            return super().forward(x)
-        return super().forward(x / (x.abs().max().detach()/8))
-
-class GPT2ParallelSelfAttention(torch.nn.Module):
-    """Parallel self-attention layer for GPT2.
-
-    Self-attention layer takes input with size [b, s, h] where b is
-    the batch size, s is the sequence length, and h is the hidden size
-    and creates output of the same size.
-    Arguments:
-        hidden_size: total hidden size of the layer (h).
-        num_attention_heads: number of attention heads (n). Note that we
-                             require n to be divisible by number of GPUs
-                             used to parallelize the model. Also, we
-                             require hidden size to be divisible by n.
-        dropout_prob: dropout probability for the attention scores.
-        init_method: weight initialization.
-        output_layer_init_method: output layer initialization. If None, use
-                                  `init_method`.
-    We use the following notation:
-        h: hidden_size
-        n: num_attention_heads
-        p: number of partitions
-        np: n/p
-        hp: h/p
-        hn: h/n
-        b: batch size
-        s: sequence length
-    """
-    def __init__(self, hidden_size, num_attention_heads,
-                 attention_dropout_prob, output_dropout_prob,
-                 init_method, layer_id, output_layer_init_method=None,sparse_config=None,
-                 finetune=False):
-        super(GPT2ParallelSelfAttention, self).__init__()
-        # Set output layer initialization if not provided.
-        if output_layer_init_method is None:
-            output_layer_init_method = init_method
-        self.layer_id = layer_id
-        # Per attention head and per partition values.
-        world_size = get_model_parallel_world_size()
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size,
-                                                     num_attention_heads)
-        self.num_attention_heads_per_partition = divide(num_attention_heads,
-                                                        world_size)
-
-        # Strided linear layer.
-        self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
-                                                    stride=3,
-                                                    gather_output=False,
-                                                    init_method=init_method)
-
-        # Dropout. Note that for a single iteration, this layer will generate
-        # different outputs on different number of parallel partitions but
-        # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
-
-        # Output.
-        self.dense = RowParallelLinear(hidden_size,
-                                       hidden_size,
-                                       input_is_parallel=True,
-                                       init_method=output_layer_init_method)
-        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
-
-        if deepspeed.checkpointing.is_configured():
-            global get_cuda_rng_tracker, checkpoint
-            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
-            checkpoint = deepspeed.checkpointing.checkpoint
-
-        self.sparse_config = sparse_config
-
-        if finetune: 
-            # build new branch
-            self.query_key_value_plus = ColumnParallelLinear(hidden_size, 3*hidden_size,
-                                                    stride=3,
-                                                    gather_output=False,
-                                                    init_method=init_method)
-            self.dense_plus = RowParallelLinear(hidden_size,
-                                       hidden_size,
-                                       input_is_parallel=True,
-                                       init_method=output_layer_init_method)
-
-    def init_plus_from_old(self):
-        self.query_key_value_plus.weight.data.copy_(self.query_key_value.weight.data)
-        if hasattr(self.query_key_value_plus, 'bias') and hasattr(self.query_key_value, 'bias'):
-            self.query_key_value_plus.bias.data.copy_(self.query_key_value.bias.data)
-        
-        self.dense_plus.weight.data.copy_(self.dense.weight.data)
-        if hasattr(self.dense_plus, 'bias') and hasattr(self.dense, 'bias'):
-            self.dense_plus.bias.data.copy_(self.dense.bias.data)
-    def reset_sparse_config(self, config):
-        self.sparse_config = config
-
-    def _transpose_for_scores(self, tensor):
-        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
-        size [b, np, s, hn].
-        """
-        new_tensor_shape = tensor.size()[:-1] + \
-                           (self.num_attention_heads_per_partition,
-                            self.hidden_size_per_attention_head)
-        tensor = tensor.view(*new_tensor_shape)
-        return tensor.permute(0, 2, 1, 3)
-
-
-    def forward(self, hidden_states, mask, mem=None):
-        sparse_config = self.sparse_config
-        layout = sparse_config.layout
-        if sparse_config.sparse_type == 'cuda_2d':
-            assert hidden_states.size(1) == sparse_config.layout[-1]
-            # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
-            hidden_states_plus = hidden_states[:, layout[1]:]
-            hidden_states = hidden_states[:, :layout[1]]
-
-        mixed_raw_layer = self.query_key_value(hidden_states)
-        (mixed_query_layer,
-            mixed_key_layer,
-            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-        if mem is not None and len(mem) > 0:
-            memk, memv = split_tensor_along_last_dim(mem, 2)
-            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
-            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
-        
-        if sparse_config.sparse_type == 'cuda_2d':
-            mixed_raw_layer_plus = self.query_key_value_plus(hidden_states_plus)
-            q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer_plus, 3)
-
-        dropout_fn = self.attention_dropout if self.training else None
-
-        if sparse_config.sparse_type == 'standard':
-            query_layer = self._transpose_for_scores(mixed_query_layer)
-            
-            key_layer = self._transpose_for_scores(mixed_key_layer)
-            value_layer = self._transpose_for_scores(mixed_value_layer)
-            
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, txt_len=layout[0] if not self.training else -1)
-            
-            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-            new_context_layer_shape = context_layer.size()[:-2] + \
-                                    (self.hidden_size_per_partition,)
-            context_layer = context_layer.view(*new_context_layer_shape)
-            
-        elif sparse_config.sparse_type == 'cuda_2d':
-            context_layer0, context_layer1 = sparse_attention_2d_light(
-                mixed_query_layer, mixed_key_layer, mixed_value_layer,
-                q1, k1, v1,
-                mask,
-                n_head=self.num_attention_heads_per_partition,
-                text_len=sparse_config.layout[0],
-                kernel_size=sparse_config.kernel_size,
-                kernel_size2=sparse_config.kernel_size2,
-                attention_dropout=dropout_fn,
-                text_start=(1-mask[...,-1,:]).sum().long().item()+1 if not self.training else -1,
-                layer_id=self.layer_id
-            )
-
-        if sparse_config.sparse_type == 'cuda_2d':
-            output_0 = self.dense(context_layer0)
-            output_1 = self.dense_plus(context_layer1)
-            output = torch.cat((output_0, output_1), dim=1)
-        else:
-            output = self.dense(context_layer)
-            
-        if self.training:
-            output = self.output_dropout(output)
-        
-        if mem is not None:
-            new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
-        else:
-            new_mem = None
-        return output, new_mem
-
-
-@torch.jit.script
-def gelu_impl(x):
-     """OpenAI's gelu implementation."""
-     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
-                                        (1.0 + 0.044715 * x * x)))
-
-def gelu(x): 
-    return gelu_impl(x)
-
-
-class GPT2ParallelMLP(torch.nn.Module):
-    """MLP for GPT2.
-
-    MLP will take the input with h hidden state, project it to 4*h
-    hidden dimension, perform gelu transformation, and project the
-    state back into h hidden dimension. At the end, dropout is also
-    applied.
-
-    Arguments:
-        hidden_size: The hidden size of the self attention.
-        output_dropout_prob: dropout probability for the outputs
-                             after self attention and final output.
-        init_method: initialization method used for the weights. Note
-                     that all biases are initialized to zero and
-                     layernorm weight are initialized to one.
-        output_layer_init_method: output layer initialization. If None,
-                                  use `init_method`.
-    """
-
-    def __init__(self, hidden_size, output_dropout_prob, init_method,
-                 output_layer_init_method=None):
-        super(GPT2ParallelMLP, self).__init__()
-        # Set output layer initialization if not provided.
-        if output_layer_init_method is None:
-            output_layer_init_method = init_method
-        # Project to 4h.
-        self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4*hidden_size,
-                                                  gather_output=False,
-                                                  init_method=init_method)
-        # Project back to h.
-        self.dense_4h_to_h = RowParallelLinear(
-            4*hidden_size,
-            hidden_size,
-            input_is_parallel=True,
-            init_method=output_layer_init_method)
-        self.dropout = torch.nn.Dropout(output_dropout_prob)
-
-    def forward(self, hidden_states):
-        # [b, s, 4hp]
-
-        intermediate_parallel = self.dense_h_to_4h(hidden_states)
-        intermediate_parallel = gelu(intermediate_parallel)
-
-        # [b, s, h]
-        output = self.dense_4h_to_h(intermediate_parallel)
-        if self.training:
-            output = self.dropout(output)
-        return output
-
-
-class GPT2ParallelTransformerLayer(torch.nn.Module):
-    """A single layer transformer for GPT2.
-
-    We use the following notation:
-        h: hidden size
-        n: number of attention heads
-        b: batch size
-        s: sequence length
-    Transformore layer takes input with size [b, s, h] and returns an
-    output of the same size.
-
-    Arguments:
-        hidden_size: The hidden size of the self attention.
-        num_attention_heads: number of attention head in the self
-                             attention.
-        attention_dropout_prob: dropout probability of the attention
-                                score in self attention.
-        output_dropout_prob: dropout probability for the outputs
-                             after self attention and final output.
-        layernorm_epsilon: epsilon used in layernorm to avoid
-                           division by zero.
-        init_method: initialization method used for the weights. Note
-                     that all biases are initialized to zero and
-                     layernorm weight are initialized to one.
-        output_layer_init_method: output layers (attention output and
-                                  mlp output) initialization. If None,
-                                  use `init_method`.
-    """
-    def __init__(self,
-                 hidden_size,
-                 num_attention_heads,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 layernorm_epsilon,
-                 init_method,
-                 layer_id,
-                 output_layer_init_method=None,
-                 sandwich_ln=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard'),
-                 finetune=False
-                 ):
-        super(GPT2ParallelTransformerLayer, self).__init__()
-        # Set output layer initialization if not provided.
-        if output_layer_init_method is None:
-            output_layer_init_method = init_method
-        self.layer_id = layer_id
-
-        # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-
-        # Self attention.
-        self.attention = GPT2ParallelSelfAttention(
-            hidden_size,
-            num_attention_heads,
-            attention_dropout_prob,
-            output_dropout_prob,
-            init_method,
-            layer_id,
-            output_layer_init_method=output_layer_init_method,
-            sparse_config=sparse_config,
-            finetune=finetune
-            )
-
-        # Layernorm on the input data.
-        self.post_attention_layernorm = LayerNorm(hidden_size,
-                                                  eps=layernorm_epsilon)
-        self.sandwich_ln = sandwich_ln
-        if sandwich_ln:
-            self.third_layernorm = LayerNorm(hidden_size,
-                                                    eps=layernorm_epsilon)
-            self.fourth_layernorm = LayerNorm(hidden_size,
-                                                    eps=layernorm_epsilon)
-
-        # MLP
-        self.mlp = GPT2ParallelMLP(
-            hidden_size,
-            output_dropout_prob,
-            init_method,
-            output_layer_init_method=output_layer_init_method)
-
-        self.sparse_config = sparse_config
-
-    def reset_sparse_config(self, config):
-            self.sparse_config = config
-            self.attention.reset_sparse_config(config)
-    
-    def forward(self, hidden_states, ltor_mask, mem=None):
-        # hidden_states: [b, s, h]
-        # ltor_mask: [1, 1, s, s]
-
-        # Layer norm at the begining of the transformer layer.
-        layernorm_output1 = self.input_layernorm(hidden_states)
-        # Self attention.
-        attention_output, qkv = self.attention(layernorm_output1, ltor_mask, mem)
-
-        # Third LayerNorm
-        if self.sandwich_ln:
-            attention_output = self.third_layernorm(attention_output)
-
-        # Residual connection.
-        layernorm_input = hidden_states + attention_output
-        # Layer norm post the self attention.
-        layernorm_output = self.post_attention_layernorm(layernorm_input)
-        # MLP.
-        mlp_output = self.mlp(layernorm_output)
-
-        # Fourth LayerNorm
-        if self.sandwich_ln:
-            mlp_output = self.fourth_layernorm(mlp_output)
-
-        # Second residual connection.
-        output = layernorm_input + mlp_output
-
-        return output, qkv
-
-def unscaled_init_method(sigma):
-    """Init method based on N(0, sigma)."""
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
-
-    return init_
-
-
-def scaled_init_method(sigma, num_layers):
-    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
-    std = sigma / math.sqrt(2.0 * num_layers)
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
-class GPT2ParallelTransformer(torch.nn.Module):
-    """GPT-2 transformer.
-
-    This module takes input from embedding layer and it's output can
-    be used directly by a logit layer. It consists of L (num-layers)
-    blocks of:
-        layer norm
-        self attention
-        residual connection
-        layer norm
-        mlp
-        residual connection
-    followed by a final layer norm.
-
-    Arguments:
-        num_layers: Number of transformer layers.
-        hidden_size: The hidden size of the self attention.
-        num_attention_heads: number of attention head in the self
-                             attention.
-        attention_dropout_prob: dropout probability of the attention
-                                score in self attention.
-        output_dropout_prob: dropout probability for the outputs
-                             after self attention and final output.
-        checkpoint_activations: if True, checkpoint activations.
-        checkpoint_num_layers: number of layers to checkpoint. This
-                               is basically the chunk size in checkpoitning.
-        layernorm_epsilon: epsilon used in layernorm to avoid
-                           division by zero.
-        init_method_std: standard deviation of the init method which has
-                         the form N(0, std).
-        use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)
-                                            scaling for the output weights (
-                                            output of self attention and mlp).
-    """
-    def __init__(self,
-                 num_layers,
-                 hidden_size,
-                 num_attention_heads,
-                 max_sequence_length,
-                 max_memory_length,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 use_scaled_init_for_output_weights=True,
-                 sandwich_ln=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard'),
-                 finetune=False
-                 ):
-        super(GPT2ParallelTransformer, self).__init__()
-        # Store activation checkpoiting flag.
-        self.checkpoint_activations = checkpoint_activations
-        self.checkpoint_num_layers = checkpoint_num_layers
-        self.max_memory_length = max_memory_length
-        self.max_sequence_length = max_sequence_length
-
-        output_layer_init_method = None
-        if use_scaled_init_for_output_weights:
-            output_layer_init_method = scaled_init_method(init_method_std,
-                                                      num_layers)
-        # Embeddings dropout
-        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
-
-        # Position embedding (serial).
-        self.position_embeddings = torch.nn.Embedding(max_sequence_length,
-                                                        hidden_size)
-        # Initialize the position embeddings.
-        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
-
-        if finetune:
-            self.position_embeddings_plus = torch.nn.Embedding(4096, # FIXME
-                                                            hidden_size)
-            # Initialize the position embeddings.
-            torch.nn.init.normal_(self.position_embeddings_plus.weight, mean=0.0, std=init_method_std)
-
-        def get_layer(layer_id):
-            return GPT2ParallelTransformerLayer(
-                hidden_size,
-                num_attention_heads,
-                attention_dropout_prob,
-                output_dropout_prob,
-                layernorm_epsilon,
-                unscaled_init_method(init_method_std),
-                layer_id,
-                output_layer_init_method=output_layer_init_method,
-                sandwich_ln=sandwich_ln,
-                sparse_config=sparse_config,
-                finetune=finetune
-                )
-
-        # Transformer layers.
-        self.layers = torch.nn.ModuleList(
-            [get_layer(layer_id) for layer_id in range(num_layers)])
-
-        # Final layer norm before output.
-        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-
-        if deepspeed.checkpointing.is_configured():
-            global get_cuda_rng_tracker, checkpoint
-            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
-            checkpoint = deepspeed.checkpointing.checkpoint
-        self.sparse_config = sparse_config
-
-    def init_plus_from_old(self):
-        self.position_embeddings_plus.weight.data.view(4, 1024, -1).copy_(self.position_embeddings.weight.data[-1024:]) # FIXME
-        for layer in self.layers:
-            layer.attention.init_plus_from_old()
-
-    def reset_sparse_config(self, config):
-            self.sparse_config = config
-            for layer in self.layers:
-                layer.reset_sparse_config(config)
-
-    def forward(self, hidden_states, position_ids, attention_mask, *mems):
-
-        batch_size, query_length = hidden_states.size()[:2]
-        memory_length = mems[0].size(1) if mems else 0
-        key_length = query_length + memory_length
-
-        # legacy
-        if isinstance(attention_mask, int) or attention_mask.numel() == 1:
-            # if given a int "sep", means the seperation of full attention part and single direction part
-            # attention mask is the beginning postion of B region, \in [0, query_len)
-            sep = attention_mask
-            # conventional transformer
-            def build_mask_matrix(query_length, key_length, sep):
-                m = torch.ones((1, query_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype)
-                assert query_length <= key_length
-                m[0, :, -query_length:] = torch.tril(m[0, :, -query_length:])
-                m[0, :, :sep + (key_length - query_length)] = 1
-                m = m.unsqueeze(1)
-                return m
-            attention_mask = build_mask_matrix(query_length, key_length, sep)
-
-    
-        if self.sparse_config.sparse_type == 'cuda_2d':
-            position = position_ids[..., :self.sparse_config.layout[1]]
-            position_plus = position_ids[..., self.sparse_config.layout[1]:]
-            position_embeddings = torch.cat(
-                (self.position_embeddings(position), self.position_embeddings_plus(position_plus)), dim=-2)
-        else:
-            position_embeddings = self.position_embeddings(position_ids)
-        hidden_states = hidden_states + position_embeddings
-        hidden_states = self.embedding_dropout(hidden_states)
-
-        mem_layers = []
-        def custom(start, end):
-            def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
-                x_, mask, mems_ = inputs[0], inputs[1], inputs[2:]
-            
-                for i, layer in enumerate(layers_):
-                    if mems_:
-                        mem_i_ = mems_[i]  
-                    elif self.max_memory_length > 0:
-                        mem_i_ = []
-                    else:
-                        mem_i_ = None
-                    x_, qkv = layer(x_, mask, mem=mem_i_)
-                    if self.max_memory_length > 0:
-                        mem_layers.append(qkv)
-                return x_
-            return custom_forward
-
-        attention_mask_saved = attention_mask
-        
-        if self.checkpoint_activations:
-            l = 0
-            num_layers = len(self.layers)
-            chunk_length = self.checkpoint_num_layers
-            while l < num_layers:
-                args = [hidden_states, attention_mask_saved]
-
-                if mems:
-                    args += mems[l: l + chunk_length]
-
-                hidden_states = checkpoint(custom(l, l + chunk_length), *args)
-                l += chunk_length
-        else:
-            for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask_saved]
-                if mems:
-                    mem_i = mems[i]  
-                elif self.max_memory_length > 0:
-                    mem_i = []
-                else:
-                    mem_i = None
-                hidden_states, qkv = layer(*args, mem=mem_i)
-                if self.max_memory_length > 0:
-                    mem_layers.append(qkv) 
-
-        # Final layer norm.
-        output = self.final_layernorm(hidden_states)
-
-        return (output, *mem_layers)
-        
-
-def _chunk(x, w, times):
-    '''convert into overlapping chunkings. Chunk size = times * w, overlap size = w
-    Args:
-        x: [b, np, s, hn]
-        ...
-    '''
-    s = x.size(2)
-    # x pad to [b, np, s+xx to k*w + w*(times-1), hn]
-    assert s % w == 0
-    npad = (times-1) * w
-    x = torch.nn.functional.pad(x, (0, 0, npad, 0), value=0)
-
-    x = x.view(x.size(0), x.size(1),  x.size(2) // w, w, x.size(3))
-
-    chunk_size = list(x.size())
-    chunk_stride = list(x.stride())
-
-    chunk_size[2] = chunk_size[2] - times + 1
-
-    chunk_size[3] = w * times
-
-    return x.as_strided(size=chunk_size, stride=chunk_stride)
-
-def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None, layer_id = -1, txt_len=-1):
-    # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. 
-    # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 
-
-    if len(attention_mask.shape) == 3:
-        attention_mask = attention_mask.unsqueeze(1)
-    # Raw attention scores. [b, np, s, s]
-    attention_scores = torch.matmul(query_layer / math.sqrt(query_layer.shape[-1]), key_layer.transpose(-1, -2))
-    
-    # Apply the left to right attention mask.
-    if attention_mask.shape[2] > 1:
-        attention_scores = torch.mul(attention_scores, attention_mask) - \
-                    10000.0 * (1.0 - attention_mask)
-    
-    # Attention probabilities [b, np, s, s]
-    attention_probs = F.softmax(attention_scores, dim=-1)
-
-    if txt_len > 0:
-        t = key_layer.shape[-2] - txt_len - 1
-        if t // 32 <= 32:
-            # line = attention_probs[..., :, 1:txt_len].max(dim=-1, keepdim=True)[0]
-            # tmask = attention_probs[..., :, 1:txt_len] >= line
-            attention_probs[..., :, 1:txt_len] *= 6 if txt_len <= 10 else 4
-            attention_probs /= attention_probs.sum(dim=-1, keepdim=True)[0]
-
-    if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
-            attention_probs = attention_dropout(attention_probs)
-    # Context layer.
-    # [b, np, s, hn]
-
-    context_layer = torch.matmul(attention_probs, value_layer)
-    return context_layer
-
-
-def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, text_start = -1, layer_id=-1, **kwargs):
-    '''
-    q0, k0, v0: [batch_size, 1088, hidden_size]
-    q1, k1, v1: [batch_size, 4096, h2]
-    n_head: int
-    attention_mask: [batch_size, 1088, 1088]
-    '''
-    from .local_attention_function import f_similar, f_weighting
-    b, s0, h0 = q0.shape
-    b, s1, h1 = q1.shape
-    assert v1.shape[-1] == h0, 'q1, k1 can be smaller, but v1 cannot.'
-    h = h0 // n_head
-    l0, l1 = int(math.sqrt(s0-text_len)+0.0001), int(math.sqrt(s1)+0.0001)
-
-    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
-    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
-    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
-    # standard attention for level 0
-    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
-    attention_scores = torch.mul(attention_scores, attention_mask) - \
-                    10000.0 * (1.0 - attention_mask)
-    attention_probs0 = F.softmax(attention_scores, dim=-1)
-    if text_start > 0:
-        attention_probs0[..., :, text_start:text_len-2] *= 1
-        attention_probs0 /= attention_probs0.sum(dim=-1, keepdim=True)[0]
-    # local attention for level 1
-    q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
-    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
-    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
-    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
-    # attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
-
-    # cross attention
-    k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
-    scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
-    scores_1 = torch.cat(
-        (
-            scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]),
-            scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
-        ),
-        dim=-1)
-    attention_probs1 = F.softmax(scores_1, dim=-1)
-
-    if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
-            attention_probs0 = attention_dropout(attention_probs0)
-            attention_probs1 = attention_dropout(attention_probs1)
-        
-    # weighting for level 0
-    context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
-    # weighting for level 1
-    probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
-    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
-    context1 = context1_to_1.view(b, n_head * h, l1**2)
-    # weighting for cross attention
-    probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
-    v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
-    context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
-    context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
-    context1 = context1 + context1_to_0
-    return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
\ No newline at end of file
diff --git a/mpu/transformer.py b/mpu/transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca2e9b94d9a84f1fd059eb2e1d2c3832ddb430d1
--- /dev/null
+++ b/mpu/transformer.py
@@ -0,0 +1,412 @@
+# coding=utf-8
+# rewritten, Copyright (c) 2021, Ming Ding.  All rights reserved.
+# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer."""
+
+import math
+import copy
+import torch
+import torch.nn.functional as F
+from apex.normalization.fused_layer_norm import FusedLayerNorm
+
+from .initialize import get_model_parallel_world_size
+from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
+
+import deepspeed
+
+from .random import checkpoint
+from .random import get_cuda_rng_tracker
+
+from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
+from .utils import split_tensor_along_last_dim
+
+class LayerNorm(FusedLayerNorm):
+    def __init__(self, *args, pb_relax=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pb_relax = pb_relax
+    def forward(self, x):
+        if not self.pb_relax:
+            return super().forward(x)
+        return super().forward(x / (x.abs().max().detach()/8))
+        
+def standard_attention(query_layer, key_layer, value_layer, attention_mask,
+                    attention_dropout=None, log_attention_weights=None):
+    # We disable the PB-relax attention here and only change the order of computation, because it is enough for most training.
+    # The implementation in the paper can be added easily if you really need it to train very deep transformers.
+
+    attention_scores = torch.matmul(
+        query_layer / math.sqrt(query_layer.shape[-1]),
+        key_layer.transpose(-1, -2)
+    )
+    
+    if attention_mask.shape[-2] > 1: # if auto-regressive, skip
+        attention_scores = torch.mul(attention_scores, attention_mask) - \
+                    10000.0 * (1.0 - attention_mask)
+    if log_attention_weights is not None:
+        attention_scores += log_attention_weights
+    
+    attention_probs = F.softmax(attention_scores, dim=-1)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            attention_probs = attention_dropout(attention_probs)
+
+    context_layer = torch.matmul(attention_probs, value_layer)
+    return context_layer
+
+class SelfAttention(torch.nn.Module):
+    def __init__(self, hidden_size, num_attention_heads,
+                attention_dropout_prob, output_dropout_prob,
+                init_method, layer_id, output_layer_init_method=None,
+                hooks={}):
+        super(SelfAttention, self).__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Per attention head and per partition values.
+        world_size = get_model_parallel_world_size()
+        self.hidden_size_per_partition = divide(hidden_size, world_size)
+        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+
+        # Strided linear layer.
+        self.query_key_value = ColumnParallelLinear(
+            hidden_size, 
+            3*hidden_size,
+            stride=3,
+            gather_output=False,
+            init_method=init_method
+        )
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+        self.dense = RowParallelLinear(
+            hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method
+        )
+        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+
+
+    def _transpose_for_scores(self, tensor):
+        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+        size [b, np, s, hn].
+        """
+        new_tensor_shape = tensor.size()[:-1] + \
+                            (self.num_attention_heads_per_partition,
+                            self.hidden_size_per_attention_head)
+        tensor = tensor.view(*new_tensor_shape)
+        return tensor.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, mask, *other_tensors):
+        if 'attention_forward' in self.hooks:
+            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors)
+        else:
+            mixed_raw_layer = self.query_key_value(hidden_states)
+            (mixed_query_layer,
+                mixed_key_layer,
+                mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+            dropout_fn = self.attention_dropout if self.training else None
+
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+            
+            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+            context_layer = context_layer.view(*new_context_layer_shape)
+            output = self.dense(context_layer)
+            
+            if self.training:
+                output = self.output_dropout(output)
+            
+            return output, None
+
+
+class MLP(torch.nn.Module):
+    def __init__(self, hidden_size, output_dropout_prob, init_method,
+                output_layer_init_method=None, layer_id=None, hooks={}):
+        super(MLP, self).__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Project to 4h.
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4*hidden_size,
+            gather_output=False,
+            init_method=init_method
+        )
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            4*hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method
+        )
+        self.dropout = torch.nn.Dropout(output_dropout_prob)
+
+    def forward(self, hidden_states, *other_tensors):
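+        # Note: a user-supplied hooks['mlp_forward'] receives the hidden
+        # states (plus any extra tensors and this layer's id) and should
+        # return a tensor of the same shape; dropout is still applied below.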
+        if 'mlp_forward' in self.hooks:
+            output = self.hooks['mlp_forward'](hidden_states, *other_tensors, layer_id=self.layer_id)
+        else:
+            intermediate_parallel = self.dense_h_to_4h(hidden_states)
+            intermediate_parallel = gelu(intermediate_parallel)
+            output = self.dense_4h_to_h(intermediate_parallel)
+            
+        if self.training:
+            output = self.dropout(output)
+        return output
+
+
+class BaseTransformerLayer(torch.nn.Module):
+    """A single layer transformer for GPT2.
+
+    We use the following notation:
+        h: hidden size
+        n: number of attention heads
+        b: batch size
+        s: sequence length
+    The transformer layer takes an input of size [b, s, h] and returns an
+    output of the same size.
+
+    Arguments:
+        hidden_size: The hidden size of the self attention.
+        num_attention_heads: number of attention heads in the self
+                             attention.
+        attention_dropout_prob: dropout probability of the attention
+                                score in self attention.
+        output_dropout_prob: dropout probability for the outputs
+                             after self attention and final output.
+        layernorm_epsilon: epsilon used in layernorm to avoid
+                           division by zero.
+        init_method: initialization method used for the weights. Note
+                     that all biases are initialized to zero and
+                     layernorm weights are initialized to one.
+        output_layer_init_method: output layers (attention output and
+                                  mlp output) initialization. If None,
+                                  use `init_method`.
+    """
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        attention_dropout_prob,
+        output_dropout_prob,
+        layernorm_epsilon,
+        init_method,
+        layer_id,
+        output_layer_init_method=None,
+        sandwich_ln=True,
+        hooks={}
+    ):
+        super(BaseTransformerLayer, self).__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.layer_id = layer_id
+        self.hooks = hooks
+
+        # Layernorm on the input data.
+        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+
+        # Self attention.
+        self.attention = SelfAttention(
+            hidden_size,
+            num_attention_heads,
+            attention_dropout_prob,
+            output_dropout_prob,
+            init_method,
+            layer_id,
+            output_layer_init_method=output_layer_init_method,
+            hooks=hooks
+        )
+
+        # Layernorm after the self attention.
+        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.sandwich_ln = sandwich_ln
+        if sandwich_ln:
+            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+
+        # MLP
+        self.mlp = MLP(
+            hidden_size,
+            output_dropout_prob,
+            init_method,
+            output_layer_init_method=output_layer_init_method,
+            layer_id=layer_id,
+            hooks=hooks
+        )
+    
+    def forward(self, hidden_states, mask, *other_tensors):
+        '''
+            hidden_states: [batch, seq_len, hidden_size]
+            mask: [(1, 1), seq_len, seq_len]
+        '''
+
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output1 = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output, output_this_layer = self.attention(layernorm_output1, mask, *other_tensors)
+
+        # Third LayerNorm
+        if self.sandwich_ln:
+            attention_output = self.third_layernorm(attention_output)
+
+        # Residual connection.
+        layernorm_input = hidden_states + attention_output
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+        # MLP.
+        mlp_output = self.mlp(layernorm_output, *other_tensors)
+
+        # Fourth LayerNorm
+        if self.sandwich_ln:
+            mlp_output = self.fourth_layernorm(mlp_output)
+
+        # Second residual connection.
+        output = layernorm_input + mlp_output
+
+        return output, output_this_layer # temporarily, output_this_layer only comes from the attention block
+
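+# Data flow of the layer above (descriptive summary): x -> input LN -> attention
+# -> (third LN if sandwich_ln) -> residual add with x -> post-attention LN
+# -> MLP -> (fourth LN if sandwich_ln) -> second residual add.
+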
+class BaseTransformer(torch.nn.Module):
+    def __init__(self,
+                 num_layers,
+                 vocab_size,
+                 hidden_size,
+                 num_attention_heads,
+                 max_sequence_length,
+                 embedding_dropout_prob,
+                 attention_dropout_prob,
+                 output_dropout_prob,
+                 checkpoint_activations,
+                 checkpoint_num_layers=1,
+                 layernorm_epsilon=1.0e-5,
+                 init_method_std=0.02,
+                 sandwich_ln=True,
+                 parallel_output=True,
+                 hooks={}
+                 ):
+        super(BaseTransformer, self).__init__()
+        if deepspeed.checkpointing.is_configured():
+            global get_cuda_rng_tracker, checkpoint
+            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+            checkpoint = deepspeed.checkpointing.checkpoint
+        
+        # recording parameters
+        self.parallel_output = parallel_output
+        self.checkpoint_activations = checkpoint_activations
+        self.checkpoint_num_layers = checkpoint_num_layers
+        self.max_sequence_length = max_sequence_length
+        self.hooks = copy.copy(hooks) # hooks will be updated each forward
+        
+        # create embedding parameters
+        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+        
+        self.word_embeddings = VocabParallelEmbedding(
+            vocab_size, hidden_size, init_method=unscaled_init_method(init_method_std))
+        
+        self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
+        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
+
+        # create all layers
+        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
+        self.init_method = unscaled_init_method(init_method_std)
+        def get_layer(layer_id):
+            return BaseTransformerLayer(
+                hidden_size,
+                num_attention_heads,
+                attention_dropout_prob,
+                output_dropout_prob,
+                layernorm_epsilon,
+                self.init_method,
+                layer_id,
+                output_layer_init_method=self.output_layer_init_method,
+                sandwich_ln=sandwich_ln,
+                hooks=hooks
+                )
+        self.layers = torch.nn.ModuleList(
+            [get_layer(layer_id) for layer_id in range(num_layers)])
+
+        # Final layer norm before output.
+        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+
+    def forward(self, input_ids, position_ids, attention_mask, *other_tensors):
+        # sanity check 
+        assert len(input_ids.shape) == 2 
+        batch_size, query_length = input_ids.shape
+        assert len(position_ids.shape) <= 2
+        assert position_ids.shape[-1] == query_length
+        assert len(attention_mask.shape) == 2 or \
+            (len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1)
+
+        # embedding part
+        if 'word_embedding_forward' in self.hooks:
+            hidden_states = self.hooks['word_embedding_forward'](input_ids, *other_tensors)
+        else: # default
+            hidden_states = self.word_embeddings(input_ids)
+            
+        if 'position_embedding_forward' in self.hooks:
+            position_embeddings = self.hooks['position_embedding_forward'](position_ids, *other_tensors)
+        else:
+            position_embeddings = self.position_embeddings(position_ids)    
+        hidden_states = hidden_states + position_embeddings
+        hidden_states = self.embedding_dropout(hidden_states)
+
+        # define custom_forward for checkpointing
+        output_per_layers = []
+        if self.checkpoint_activations:
+            def custom(start, end):
+                def custom_forward(*inputs):
+                    layers_ = self.layers[start:end]
+                    x_, mask, *other_tensors = inputs
+                    for i, layer in enumerate(layers_):
+                        x_, output_this_layer = layer(x_, mask, *other_tensors)
+                        output_per_layers.append(output_this_layer)
+                    return x_
+                return custom_forward
+        
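+            # Run the layers in chunks of `checkpoint_num_layers`: only chunk
+            # boundaries keep activations; everything inside a chunk is
+            # recomputed during the backward pass to save memory.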
+            l, num_layers = 0, len(self.layers)
+            chunk_length = self.checkpoint_num_layers
+            while l < num_layers:
+                args = [hidden_states, attention_mask, *other_tensors]
+                hidden_states = checkpoint(custom(l, l + chunk_length), *args)
+                l += chunk_length
+        else:
+            for i, layer in enumerate(self.layers):
+                args = [hidden_states, attention_mask, *other_tensors]
+                hidden_states, output_this_layer = layer(*args)
+                output_per_layers.append(output_this_layer) 
+
+        # Final layer norm.
+        logits = self.final_layernorm(hidden_states)
+        
+        if 'final_forward' in self.hooks:
+            logits_parallel = self.hooks['final_forward'](logits, *other_tensors)
+        else:
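+            # Tied output projection: logits are computed against the
+            # (vocab-parallel) word embedding weights inside the
+            # model-parallel region.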
+            logits_parallel = copy_to_model_parallel_region(logits)
+            logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
+
+        if self.parallel_output:
+            return (logits_parallel, *output_per_layers)
+        return (gather_from_model_parallel_region(logits_parallel), *output_per_layers)
+        
diff --git a/mpu/utils.py b/mpu/utils.py
index d9b1a8d0d8f57dd5bc172d3b1222596297bfed6d..c83f501889ccabb33cce33260f7d4fee9eadcab7 100755
--- a/mpu/utils.py
+++ b/mpu/utils.py
@@ -70,15 +70,29 @@ class VocabUtility:
         return VocabUtility.vocab_range_from_per_partition_vocab_size(
             per_partition_vocab_size, rank, world_size)
 
-def split_out_sums(x, BLOCK_SIZE=32, all_ret=False):
-    b, L = x.shape[:2]
-    rs = x.shape[2:]
-    x = x.view(b, L // BLOCK_SIZE, BLOCK_SIZE, *rs)
-    oris, sums = x.split([BLOCK_SIZE-1, 1], dim=2)
-    if all_ret:
-        return oris.reshape(b, -1, *rs), sums.reshape(b, -1, *rs)
-    else: 
-        return sums.reshape(b, -1, *rs)
-
 def sqrt(x):
-    return int(math.sqrt(x) + 1e-4)
\ No newline at end of file
+    return int(math.sqrt(x) + 1e-4)
+
+def unscaled_init_method(sigma):
+    """Init method based on N(0, sigma)."""
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
+
+    return init_
+
+def scaled_init_method(sigma, num_layers):
+    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
+    std = sigma / math.sqrt(2.0 * num_layers)
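+    # e.g. (illustrative): sigma=0.02, num_layers=24 -> std = 0.02/sqrt(48) ≈ 0.0029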
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+@torch.jit.script
+def gelu_impl(x):
+     """OpenAI's gelu implementation."""
+     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
+                                        (1.0 + 0.044715 * x * x)))
+
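+# Note: gelu_impl above is the tanh approximation of GELU
+# (0.7978845608... = sqrt(2/pi)); gelu() just dispatches to the scripted impl.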
+def gelu(x): 
+    return gelu_impl(x)
\ No newline at end of file
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 8fb6d3b355c095d1616765b6a60f6969f9620b76..c8ac52a6bf098e938e033b8087e8caf2342d1fa0 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -13,11 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Pretrain GPT2"""
-
-# Flag to use Pytorch ddp which uses overlapping communication and computation.
-USE_TORCH_DDP = True
-
 from datetime import datetime
 import os
 import random
@@ -29,24 +24,15 @@ import torch
 import deepspeed
 from contextlib import ExitStack
 from arguments import get_args
-from fp16 import FP16_Module
-from fp16 import FP16_Optimizer
 from learning_rates import AnnealingLR
-from model import GPT2Model
-from model import gpt2_get_params_for_weight_decay_optimization
 
-if USE_TORCH_DDP:
-    from model import PyTorchDistributedDataParallel as DDP
-else:
-    from model import DistributedDataParallel as DDP
 import mpu
-from apex.optimizers import FusedAdam as Adam
+from mpu import GPT2ParallelTransformer
 from utils import Timers
 from utils import save_checkpoint
 from utils import load_checkpoint
 from utils import report_memory
 from utils import print_args
-from utils import print_params_min_max_norm
 from utils import print_rank_0
 from utils import get_sample_writer
 import torch.distributed as dist
@@ -60,7 +46,7 @@ def get_model(args, sparse_config=None):
 
     print_rank_0('building CogView2 model ...')
     ml = args.max_position_embeddings
-    model = GPT2Model(num_layers=args.num_layers,
+    model = GPT2ParallelTransformer(num_layers=args.num_layers,
                       vocab_size=args.vocab_size,
                       hidden_size=args.hidden_size,
                       num_attention_heads=args.num_attention_heads,
@@ -90,26 +76,47 @@ def get_model(args, sparse_config=None):
     model.cuda(torch.cuda.current_device())
 
     # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
+    # if args.fp16:
+    #     model = FP16_Module(model)
 
     # Wrap model for distributed training.
-    if not args.deepspeed:
-        if USE_TORCH_DDP:
-            i = torch.cuda.current_device()
-            model = DDP(model, device_ids=[i], output_device=i,
-                        process_group=mpu.get_data_parallel_group())
-        else:
-            model = DDP(model)
+    # if not args.deepspeed:
+    #     if USE_TORCH_DDP:
+    #         i = torch.cuda.current_device()
+    #         model = DDP(model, device_ids=[i], output_device=i,
+    #                     process_group=mpu.get_data_parallel_group())
+    #     else:
+    #         model = DDP(model)
 
     return model
 
 
+def gpt2_get_params_for_weight_decay_optimization(module):
+    
+    weight_decay_params = {'params': []}
+    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    for module_ in module.modules():
+        if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
+            no_weight_decay_params['params'].extend(
+                [p for p in list(module_._parameters.values())
+                 if p is not None and p.requires_grad])
+        else:
+            weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n != 'bias' and p.requires_grad])
+            no_weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n == 'bias' and p.requires_grad])
+    return weight_decay_params, no_weight_decay_params
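+# Illustrative usage (hypothetical optimizer choice, not taken from this file):
+#   decay, no_decay = gpt2_get_params_for_weight_decay_optimization(model)
+#   optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-4, weight_decay=0.01)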
+
+
 def get_optimizer_param_groups(model):
     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (DDP, FP16_Module)):
+    while hasattr(model, 'module'):
+        # print(model)
         model = model.module
-    param_groups = gpt2_get_params_for_weight_decay_optimization(model)
+        
+    param_groups = gpt2_get_params_for_weight_decay_optimization(model) # TODO move to here
 
     # Add model parallel attribute if it is not set.
     for param_group in param_groups:
@@ -119,43 +126,6 @@ def get_optimizer_param_groups(model):
 
     return param_groups
 
-
-def get_optimizer(param_groups, args):
-    """Set up the optimizer."""
-    if args.cpu_optimizer:
-        #Apex FusedAdam uses decoupled weight decay so use the same here
-        if args.cpu_torch_adam:
-            cpu_adam_optimizer = torch.optim.AdamW
-        else:
-            #TODO add option for decoupled weight decay in DeepCPUAdam
-            from deepspeed.ops.adam import DeepSpeedCPUAdam
-            cpu_adam_optimizer = DeepSpeedCPUAdam
-        optimizer = cpu_adam_optimizer(param_groups,
-                        lr=args.lr, weight_decay=args.weight_decay)
-    else:
-        # Use FusedAdam.
-        optimizer = Adam(param_groups,
-                         lr=args.lr, weight_decay=args.weight_decay)
-
-    print(f'Optimizer = {optimizer.__class__.__name__}')
-    if hasattr(args, "deepspeed") and args.deepspeed:
-        raise NotImplementedError
-        # fp16 wrapper is not required for DeepSpeed.
-        # return optimizer
-
-    # Wrap into fp16 optimizer.
-    if args.fp16:
-        optimizer = FP16_Optimizer(optimizer,
-                                   static_loss_scale=args.loss_scale,
-                                   dynamic_loss_scale=args.dynamic_loss_scale,
-                                   dynamic_loss_args={
-                                       'scale_window': args.loss_scale_window,
-                                       'min_scale': args.min_scale,
-                                       'delayed_shift': args.hysteresis})
-
-    return optimizer
-
-
 def get_learning_rate_scheduler(optimizer, args):
     """Build the learning rate scheduler."""
 
@@ -184,10 +154,11 @@ def setup_model_and_optimizer(args):
 
     model = get_model(args)
 
-    if args.finetune:
+    if args.finetune: # TODO
         model.requires_grad_(False)
         for name, param in model.named_parameters():
-            if name.find('_plus') > 0:
+            # if name.find('_plus') > 0:
+            if name.find('query_key_value') >= 0 or name.find('attention.dense') >= 0 or name.find('position_embeddings') >= 0:
                 param.requires_grad_(True)
 
     param_groups = get_optimizer_param_groups(model)
@@ -195,7 +166,6 @@ def setup_model_and_optimizer(args):
     if args.train_data is not None:
         if args.deepspeed:
             print_rank_0("DeepSpeed is enabled.")
-
             model, optimizer, _, _ = deepspeed.initialize(
                 model=model,
                 model_parameters=param_groups,
@@ -204,7 +174,7 @@ def setup_model_and_optimizer(args):
                 dist_init_required=False
             )
         else:
-            optimizer = get_optimizer(param_groups, args)
+            raise ValueError('Currently, we only support training with deepspeed.')
         lr_scheduler = get_learning_rate_scheduler(optimizer, args)
     else:
         optimizer, lr_scheduler = None, None
@@ -374,26 +344,26 @@ def backward_step(optimizer, model, lm_loss, args, timers):
         # DeepSpeed backward propagation already addressed all reduce communication.
         # Reset the timer to avoid breaking timer logs below.
         timers('allreduce').reset()
-    else:
-        if not USE_TORCH_DDP:
-            timers('allreduce').start()
-            model.allreduce_params(reduce_after=False,
-                                   fp32_allreduce=args.fp32_allreduce)
-            timers('allreduce').stop()
+    # else:
+    #     if not USE_TORCH_DDP:
+    #         timers('allreduce').start()
+    #         model.allreduce_params(reduce_after=False,
+    #                                fp32_allreduce=args.fp32_allreduce)
+    #         timers('allreduce').stop()
 
     lm_loss_reduced = reduced_losses
 
     # Update master gradients.
-    if not args.deepspeed:
-        if args.fp16:
-            optimizer.update_master_grads()
+    # if not args.deepspeed:
+    #     if args.fp16:
+    #         optimizer.update_master_grads()
 
-        # Clipping gradients helps prevent the exploding gradient.
-        if args.clip_grad > 0:
-            if not args.fp16:
-                mpu.clip_grad_norm(model.parameters(), args.clip_grad)
-            else:
-                optimizer.clip_master_grads(args.clip_grad)
+    #     # Clipping gradients helps prevent the exploding gradient.
+    #     if args.clip_grad > 0:
+    #         if not args.fp16:
+    #             mpu.clip_grad_norm(model.parameters(), args.clip_grad)
+    #         else:
+    #             optimizer.clip_master_grads(args.clip_grad)
 
     return lm_loss_reduced
 
@@ -545,14 +515,14 @@ def train(model, optimizer, lr_scheduler,
             if report_memory_flag:
                 report_memory('after {} iterations'.format(args.iteration))
                 report_memory_flag = False
-            if USE_TORCH_DDP:
-                timers.log(['forward', 'backward', 'optimizer',
-                            'batch generator', 'data loader'],
-                           normalizer=args.log_interval)
-            else:
-                timers.log(['forward', 'backward', 'allreduce', 'optimizer',
+            # if USE_TORCH_DDP:
+            #     timers.log(['forward', 'backward', 'optimizer',
+            #                 'batch generator', 'data loader'],
+            #                normalizer=args.log_interval)
+            # else:
+            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                             'batch generator', 'data loader'],
-                           normalizer=args.log_interval)
+                        normalizer=args.log_interval)
         # Checkpointing
         if args.save and args.save_interval and args.iteration % args.save_interval == 0:
             save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
@@ -613,9 +583,9 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
                 deepspeed.checkpointing.reset()
 
             # Reduce across processes.
-            if isinstance(model, DDP):
-                torch.distributed.all_reduce(lm_loss.data)
-                lm_loss.data = lm_loss.data / args.world_size
+            # if isinstance(model, DDP):
+            #     torch.distributed.all_reduce(lm_loss.data)
+            #     lm_loss.data = lm_loss.data / args.world_size
 
             total_lm_loss += lm_loss.data.detach().float().item()
 
diff --git a/utils.py b/utils.py
index 3256dce3e6fc1daf6df72aaae7f15d1ca4916de0..8ae815c4582a97557d80d29afceb2b96cea788a4 100755
--- a/utils.py
+++ b/utils.py
@@ -22,7 +22,6 @@ import numpy as np
 import torch
 
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from fp16 import FP16_Optimizer
 import mpu
 import model
 from tensorboardX import SummaryWriter
@@ -53,27 +52,6 @@ def print_args(args):
         dots = '.' * (29 - len(arg))
         print('  {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
 
-
-def print_params_min_max_norm(optimizer, iteration):
-    """Print min, max, and norm of all parameters."""
-    index = 0
-    rank = torch.distributed.get_rank()
-    string = 'iteration, rank, index, model-parallel,min, max, norm\n'
-    optimizer_ = optimizer
-    if isinstance(optimizer, FP16_Optimizer):
-        optimizer_ = optimizer.optimizer
-    for param_group in optimizer_.param_groups:
-        for param in param_group['params']:
-            index += 1
-            min_ = param.data.min()
-            max_ = param.data.max()
-            norm = param.data.norm()
-            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
-                iteration, rank, index, int(param.model_parallel))
-            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
-    print(string, flush=True)
-
-
 class Timers:
     """Group of timers."""