Commit 5db7dfe4 authored by Ming Ding

tmp save3

parent 18be5e6a
@@ -41,7 +41,8 @@ class BaseModel(torch.nn.Module):
self.mixins = torch.nn.ModuleList()
def reinit(self):
# if some mixins are loaded, override this function
for m in self.mixins:
m.reinit(self.transformer)
def forward(self, *args, **kwargs):
# -*- encoding: utf-8 -*-
'''
@File : cached_autoregressive_model.py
@Time : 2021/10/02 01:36:24
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# -*- encoding: utf-8 -*-
'''
@File : cuda2d_model.py
@Time : 2021/10/02 01:36:32
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import torch.nn.functional as F
from .base_model import BaseModel
from .mixins import PositionEmbeddingMixin, AttentionMixin
from mpu.transformer import split_tensor_along_last_dim
from mpu.local_attention_function import f_similar, f_weighting
from mpu.random import get_cuda_rng_tracker
from mpu.utils import sqrt
class Cuda2dModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
additional_seqlen = args.new_sequence_length - args.max_sequence_length
self.mixins.append(PositionEmbeddingMixin(
additional_seqlen, args.hidden_size
))
self.mixins.append(AttentionMixin(
num_layers=args.num_layers,
hidden_size=args.hidden_size
))
self.layout = args.layout
# [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
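# layout[0]: end of the text segment, layout[1]: end of the level-0 (low-resolution) image tokens, layout[2]: end of the level-1 (high-resolution) image tokens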
self.kernel_size = args.kernel_size
self.kernel_size2 = args.kernel_size2
self.log_attention_weights = None
def position_embedding_forward(self, position_ids, *other_tensors):
position = position_ids[..., :self.layout[1]]
position_plus = position_ids[..., self.layout[1]:]
position_embeddings = torch.cat(
(
self.transformer.position_embeddings(position),
self.mixins[0].position_embeddings(position_plus)
),
dim=-2
)
return position_embeddings
def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
attn_module = self.transformer.layers[layer_id].attention
# attention_plus on all layers
query_key_value_plus = self.mixins[1].query_key_value[layer_id]
dense_plus = self.mixins[1].dense[layer_id]
# split two parts
hidden_states_plus = hidden_states[:, self.layout[1]:]
hidden_states = hidden_states[:, :self.layout[1]]
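# hidden_states: text + level-0 image tokens (up to layout[1]); hidden_states_plus: level-1 high-resolution tokens handled by the mixin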
# base model qkv
mixed_raw_layer = attn_module.query_key_value(hidden_states)
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
# cuda2d model qkv
mixed_raw_layer = query_key_value_plus(hidden_states_plus)
q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
dropout_fn = attn_module.attention_dropout if self.training else None
# cuda2d attention
context_layer0, context_layer1 = sparse_attention_2d_light(
q0, k0, v0,
q1, k1, v1,
mask,
n_head=attn_module.num_attention_heads_per_partition,
text_len=self.layout[0],
kernel_size=self.kernel_size,
kernel_size2=self.kernel_size2,
attention_dropout=dropout_fn,
log_attention_weights=self.log_attention_weights
)
output_0 = attn_module.dense(context_layer0)
output_1 = dense_plus(context_layer1)
output = torch.cat((output_0, output_1), dim=1)
return output
def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, **kwargs):
'''
q0, k0, v0: [batch_size, 1088, hidden_size]
q1, k1, v1: [batch_size, 4096, h2]
n_head: int
attention_mask: [batch_size, 1088, 1088]
'''
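# Two-level scheme: level 0 (q0/k0/v0) covers text + low-resolution image tokens with dense attention;
# level 1 (q1/k1/v1) covers high-resolution image tokens with local attention inside the l1 x l1 grid
# (kernel_size) plus local cross attention onto the level-0 image grid (kernel_size2).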
b, s0, h0 = q0.shape
b, s1, h1 = q1.shape
h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
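# h: per-head hidden size; l0, l1: side lengths of the level-0 and level-1 image grids (s0 - text_len = l0**2, s1 = l1**2)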
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
# standard attention for level 0
attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
if log_attention_weights is not None:
attention_scores += log_attention_weights
attention_scores = torch.mul(attention_scores, attention_mask) - \
10000.0 * (1.0 - attention_mask)
attention_probs0 = F.softmax(attention_scores, dim=-1)
# local attention for level 1
q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
# cross attention
k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
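# k0T here keeps only the image part of level 0, reshaped to an l0 x l0 grid, so each level-1 query attends to a kernel_size2 x kernel_size2 neighborhood of low-resolution keys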
scores_1 = torch.cat(
(
scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]),
scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
),
dim=-1)
attention_probs1 = F.softmax(scores_1, dim=-1)
if attention_dropout is not None:
with get_cuda_rng_tracker().fork():
attention_probs0 = attention_dropout(attention_probs0)
attention_probs1 = attention_dropout(attention_probs1)
# weighting for level 0
context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
# weighting for level 1
probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
context1 = context1_to_1.view(b, n_head * h, l1**2)
# weighting for cross attention
probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
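# v0_part: image part of the level-0 values on the l0 x l0 grid, the source for the 1-to-0 cross-attention weighting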
context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
context1 = context1 + context1_to_0
return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
@@ -25,12 +25,12 @@ class BaseMixin(torch.nn.Module):
pass
class PositionEmbeddingMixin(BaseMixin):
def __init__(self, additional_sequence_length, hidden_size,
init_method_std=0.02, reinit_slice=slice(-1024, None)
):
super(PositionEmbeddingMixin, self).__init__()
self.reinit_slice = reinit_slice
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
def reinit(self, transformer, *pre_mixins):
old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
@@ -62,82 +62,3 @@ f_similar = similarFunction.apply
f_weighting = weightingFunction.apply
class LocalAttention(nn.Module):
def __init__(self, inp_channels, out_channels, kH, kW):
super(LocalAttention, self).__init__()
self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.kH = kH
self.kW = kW
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x3 = self.conv3(x)
weight = f_similar(x1, x2, self.kH, self.kW)
weight = F.softmax(weight, -1)
out = f_weighting(x3, weight, self.kH, self.kW)
return out
class TorchLocalAttention(nn.Module):
def __init__(self, inp_channels, out_channels, kH, kW):
super(TorchLocalAttention, self).__init__()
self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
self.kH = kH
self.kW = kW
@staticmethod
def f_similar(x_theta, x_phi, kh, kw, causal_mask=False):
n, c, h, w = x_theta.size() # (N, inter_channels, H, W)
pad = (kh // 2, kw // 2)
x_theta = x_theta.permute(0, 2, 3, 1).contiguous()
x_theta = x_theta.view(n * h * w, 1, c)
x_phi = F.unfold(x_phi, kernel_size=(kh, kw), stride=1, padding=pad)
x_phi = x_phi.contiguous().view(n, c, kh * kw, h * w)
x_phi = x_phi.permute(0, 3, 1, 2).contiguous()
x_phi = x_phi.view(n * h * w, c, kh * kw)
out = x_theta @ x_phi
out = out.view(n, h, w, kh * kw)
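# optional causal mask: keep only window positions up to and including the centre, i.e. keys at or before the query in raster-scan order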
if causal_mask:
out = out[..., :kh * kw // 2 + 1]
return out
@staticmethod
def f_weighting(x_theta, x_phi, kh, kw, causal_mask=False):
n, c, h, w = x_theta.size() # (N, inter_channels, H, W)
pad = (kh // 2, kw // 2)
x_theta = F.unfold(x_theta, kernel_size=(kh, kw), stride=1, padding=pad)
x_theta = x_theta.permute(0, 2, 1).contiguous()
x_theta = x_theta.view(n * h * w, c, kh * kw)
if causal_mask:
x_theta = x_theta[..., :kh * kw // 2 + 1]
x_phi = x_phi.view(n * h * w, kh * kw // 2 + 1, 1)
else:
x_phi = x_phi.view(n * h * w, kh * kw, 1)
out = torch.matmul(x_theta, x_phi)
out = out.squeeze(-1)
out = out.view(n, h, w, c)
out = out.permute(0, 3, 1, 2).contiguous()
return out
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x3 = self.conv3(x)
weight = self.f_similar(x1, x2, self.kH, self.kW)
weight = F.softmax(weight, -1)
out = self.f_weighting(x3, weight, self.kH, self.kW)
return out
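A quick smoke test for the pure-PyTorch TorchLocalAttention shown above; the channel counts, kernel size, and input shape below are illustrative assumptions, not values from this diff.

# Illustrative smoke test for TorchLocalAttention (sizes are arbitrary).
import torch
m = TorchLocalAttention(inp_channels=16, out_channels=16, kH=9, kW=9)
x = torch.randn(2, 16, 32, 32)
y = m(x)  # local attention preserves the spatial resolution
assert y.shape == (2, 16, 32, 32)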
@@ -52,13 +52,13 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
query_layer / math.sqrt(query_layer.shape[-1]),
key_layer.transpose(-1, -2)
)
if log_attention_weights is not None:
attention_scores += log_attention_weights
if attention_mask.shape[-2] > 1: # if auto-regressive, skip
attention_scores = torch.mul(attention_scores, attention_mask) - \
10000.0 * (1.0 - attention_mask)
attention_probs = F.softmax(attention_scores, dim=-1)
if attention_dropout is not None:
@@ -179,33 +179,6 @@ class MLP(torch.nn.Module):
class BaseTransformerLayer(torch.nn.Module):
"""A single layer transformer for GPT2.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
The transformer layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention heads in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weights are initialized to one.
output_layer_init_method: output layers (attention output and
mlp output) initialization. If None,
use `init_method`.
"""
def __init__(
self,
hidden_size,
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
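# get_lr() dispatches by index into this list: 0 = linear, 1 = cosine, 2 = exponential (not implemented, falls back to constant), anything else = constant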
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, auto_warmup_steps=50, auto_warmup_rate=0.05):
assert warmup_iter <= num_iters
self.optimizer = optimizer
self.start_lr = start_lr
self.warmup_iter = warmup_iter
self.init_step = last_iter
self.num_iters = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
self.decay_ratio = 1 / decay_ratio
self.auto_warmup_steps = auto_warmup_steps
self.auto_warmup_rate = auto_warmup_rate
self.step(self.num_iters)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}')
def get_lr(self):
if self.num_iters <= self.init_step + self.auto_warmup_steps:
return float(self.start_lr) * self.auto_warmup_rate
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * self.num_iters / self.warmup_iter
else:
real_end_iter = self.end_iter - self.warmup_iter  # assumed definition (not shown in this diff): decay horizon after warmup
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/real_end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / real_end_iter)
return self.start_lr / self.decay_ratio * (
(math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
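# e.g. with the default decay_ratio=0.5 (so self.decay_ratio == 2), the cosine branch starts at start_lr and anneals to start_lr * 0.5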
elif self.decay_style == self.DECAY_STYLES[2]:
#TODO: implement exponential decay
return self.start_lr
else:
return self.start_lr
def step(self, step_num=None):
if step_num is None:
step_num = self.num_iters + 1
self.num_iters = step_num
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
sd = {
'start_lr': self.start_lr,
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
'decay_style': self.decay_style,
'end_iter': self.end_iter,
'decay_ratio': self.decay_ratio
}
return sd
def load_state_dict(self, sd):
pass # disable this
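A minimal sketch of driving AnnealingLR outside of DeepSpeed; the model, optimizer, and hyperparameter values below are illustrative assumptions, not part of this commit.

# Illustrative driver for AnnealingLR (values are hypothetical).
import torch
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = AnnealingLR(optimizer, start_lr=1e-4, warmup_iter=500,
                        num_iters=10000, decay_style='cosine', decay_ratio=0.5)
for step in range(1, 10001):
    # ... forward / backward / optimizer.step() ...
    scheduler.step()  # recomputes the lr via get_lr() and writes it into every param group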
# -*- encoding: utf-8 -*-
'''
@File : model_io.py
@Time : 2021/10/05 18:39:55
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import numpy as np
import mpu
from utils import print_rank_0
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
if release:
d = 'release'
else:
d = '{:d}'.format(iteration)
if zero:
dp_rank = mpu.get_data_parallel_rank()
d += '_zero_dp_rank_{}'.format(dp_rank)
return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
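# e.g. <checkpoints_path>/2000/mp_rank_00_model_states.pt, or <checkpoints_path>/2000_zero_dp_rank_3/mp_rank_00_model_states.pt when zero=True (iteration and ranks here are illustrative)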
def get_checkpoint_tracker_filename(checkpoints_path):
return os.path.join(checkpoints_path, 'latest')
def save_checkpoint(iteration, model, optimizer,
lr_scheduler, args):
"""Save a model checkpoint."""
if args.deepspeed:
save_ds_checkpoint(iteration, model, lr_scheduler, args)
else:
raise ValueError("training without deepspeed is not supported.")
# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def save_ds_checkpoint(iteration, model, lr_scheduler, args):
"""Save a model checkpoint."""
sd = {}
sd['iteration'] = iteration
if lr_scheduler is not None:
sd['client_lr_scheduler'] = lr_scheduler.state_dict()
# rng states.
if not args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
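# saves via DeepSpeed's internal helpers directly so that optimizer states are not written (hence 'no_optim')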
os.makedirs(save_dir, exist_ok=True)
# Ensure tag is a string
tag = str(tag)
# Ensure checkpoint tag is consistent across ranks
model._checkpoint_tag_validation(tag)
# Real save via deepspeed
model._create_checkpoint_file(save_dir, tag, False)
model._save_checkpoint(save_dir, tag, client_state=client_state)
# Save latest checkpoint tag
if save_latest:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
return True
def get_checkpoint_iteration(args):
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from random')
return 0, False, False
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
return iteration, release, True
def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True):
"""Load a model checkpoint."""
iteration, release, success = get_checkpoint_iteration(args)
if not success:
return 0
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
sd = torch.load(checkpoint_name, map_location='cpu')
assert not args.do_train or args.deepspeed
if args.deepspeed:
module = model.module
else: # inference without deepspeed
module = model
# only load module, other hyperparameters are just for recording.
missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
if len(unexpected_keys) > 0:
print_rank_0(f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
if len(missing_keys) > 0:
if not args.do_train:
raise ValueError(f'Missing keys for inference: {missing_keys}.')
else: # new params
assert all(name.find('mixins')>0 for name in missing_keys)
module.reinit() # initialize mixins
model.optimizer.refresh_fp32_params() # restore fp32 weights from module
# Iterations.
if args.mode == 'finetune':
iteration = 0
elif args.mode == 'pretrain' and not args.no_load_rng: # rng states.
try:
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load random state from checkpoint {}, exiting. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the random '
'state.'.format(checkpoint_name))
exit()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
del sd
return iteration