Commit 981fdabd authored by Ming Ding

Merge branch 'glm_finetune' into finer-attn-hooks

parents 56936ad9 fa77bfc2
from .mlp_head import MLPHeadMixin
from .prompt_tuning import PrefixTuningMixin, PTuningV2Mixin
# -*- encoding: utf-8 -*-
'''
@File : mlp_head.py
@Time : 2021/12/12 20:44:09
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict

class MLPHeadMixin(BaseMixin):
def __init__(self, hidden_size, *output_sizes, bias=True, activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005):
super().__init__()
self.activation_func = activation_func
last_size = hidden_size
self.layers = torch.nn.ModuleList()
for sz in output_sizes:
this_layer = torch.nn.Linear(last_size, sz, bias=bias)
last_size = sz
torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
self.layers.append(this_layer)

    def final_forward(self, logits, **kw_args):
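        # Apply the activation before every layer except the first; the last layer's output is returned as raw logits.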
for i, layer in enumerate(self.layers):
if i > 0:
logits = self.activation_func(logits)
logits = layer(logits)
return logits
# -*- encoding: utf-8 -*-
'''
@File : prompt_tuning.py
@Time : 2021/12/12 20:45:18
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from SwissArmyTransformer.mpu.transformer import standard_attention
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict

class PrefixTuningMixin(BaseMixin):
def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len):
super().__init__()
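        # One learnable (key, value) prefix per layer, shape (2, num_heads, prefix_len, head_dim), initialized to small random values.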
self.prefix = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head)*0.01)
for layer_id in range(num_layers)
])
self.prefix_len = prefix_len
@non_conflict
def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
prefix_k, prefix_v = self.prefix[kw_args['layer_id']]
b, nh, seq_len, hidden_size = k.shape
prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size)
prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size)
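        # Append the learned prefix along the sequence dimension of the keys and values.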
k = torch.cat((k, prefix_k), dim=2)
v = torch.cat((v, prefix_v), dim=2)
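        # Extend a non-trivial attention mask so every query can also attend to the prefix positions.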
if mask.numel() > 1:
mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype)
mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1)
mask = torch.cat((mask, mask_prefixed), dim=-1)
return old_impl(q, k, v, mask, dropout_fn, **kw_args)
PTuningV2Mixin = PrefixTuningMixin
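# Illustrative usage (mirrors finetune_glm_sst2.py below): attaching the mixin lets its
# attention_fn hook wrap the default standard_attention in every transformer layer, e.g.
#   model.add_mixin('prefix-tuning',
#                   PrefixTuningMixin(num_layers, hidden_size // num_heads, num_heads, prefix_len))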
@@ -17,7 +17,7 @@
 from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
 from SwissArmyTransformer.mpu.transformer import unscaled_init_method
 from .base_model import BaseMixin
 from .cached_autoregressive_model import CachedAutoregressiveMixin
+from .finetune import *
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size,
...
# -*- encoding: utf-8 -*-
'''
@File : finetune_glm_sst2.py
@Time : 2021/12/12 20:53:28
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
import numpy as np
from SwissArmyTransformer import mpu, get_args, get_tokenizer
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
from SwissArmyTransformer.training.deepspeed_training import training_main
from SwissArmyTransformer.data_utils import TSVDataset
from SwissArmyTransformer.model import GLMModel
from SwissArmyTransformer.mpu.transformer import standard_attention
from SwissArmyTransformer.model.mixins import MLPHeadMixin, PrefixTuningMixin

class ClassificationModel(GLMModel):
def __init__(self, args, transformer=None, parallel_output=True):
super().__init__(args, transformer=transformer, parallel_output=parallel_output)
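        # Attach a two-layer MLP head (2048 hidden units, one output logit) and per-layer prefix key/values for prefix tuning.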
self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
def disable_untrainable_params(self):
self.transformer.word_embeddings.requires_grad_(False)
# for layer_id in range(len(self.transformer.layers)):
# self.transformer.layers[layer_id].requires_grad_(False)
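
# Each batch item carries a fixed-length token sequence ('sentence', padded with -1) and an integer 'label'.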
def get_batch(data_iterator, args, timers):
# Items and their type.
keys = ['sentence', 'label']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['sentence'].long()
labels = data_b['label'].long()
batch_size, seq_length = tokens.size()
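    # GLM-style 2D position ids: row 0 holds token positions, row 1 (block positions) stays zero here.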
position_ids = torch.zeros(2, seq_length, device=tokens.device, dtype=torch.long)
torch.arange(0, seq_length, out=position_ids[0, :seq_length])
position_ids = position_ids.unsqueeze(0)
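    # Start from a full self-attention mask, then zero out columns that correspond to padding tokens (-1).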
attention_mask = torch.ones((batch_size, 1, seq_length, seq_length), device=tokens.device)
attention_mask[...,:seq_length] -= (tokens==-1).view(batch_size, 1, 1, seq_length).float()
# Convert
if args.fp16:
attention_mask = attention_mask.half()
return tokens, labels, attention_mask, position_ids, (tokens!=-1)

def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, labels, attention_mask, position_ids, loss_mask = get_batch(
data_iterator, args, timers)
timers('batch generator').stop()
logits, *mems = model(tokens, position_ids, attention_mask)
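    # Mean-pool the scalar output over non-padding positions and treat it as a single binary logit.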
pred = ((logits.contiguous().float().squeeze(-1)) * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
loss = torch.nn.functional.binary_cross_entropy_with_logits(
pred,
labels.float()
)
acc = ((pred > 0.).long() == labels).sum() / labels.numel()
return loss, {'acc': acc}

def create_dataset_function(path, args):
tokenizer = get_tokenizer()
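    # Encode one TSV row as [ENC] + tokens + [eos], then truncate or pad with -1 to sample_length.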
def process_fn(row):
sentence, label = tokenizer._encode(row[0]), int(row[1])
sentence = [tokenizer.get_command('ENC').Id] + sentence + [tokenizer.get_command('eos').Id]
if len(sentence) >= args.sample_length:
sentence = sentence[:args.sample_length]
else:
sentence.extend([-1] * (args.sample_length-len(sentence)))
return {'sentence': np.array(sentence, dtype=np.int64), 'label': label}
return TSVDataset(path, process_fn, with_heads=True)

if __name__ == '__main__':
py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--new_hyperparam', type=str, default=None)
py_parser.add_argument('--sample_length', type=int, default=80)
py_parser.add_argument('--prefix_len', type=int, default=16)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
# from cogdata.utils.ice_tokenizer import get_tokenizer as get_ice
# tokenizer = get_tokenizer(args=args, outer_tokenizer=get_ice())
training_main(args, model_cls=ClassificationModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
{
    "train_micro_batch_size_per_gpu": 64,
    "gradient_accumulation_steps": 1,
    "steps_per_print": 10,
    "gradient_clipping": 0.1,
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 400,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00001,
            "betas": [0.9, 0.95],
            "eps": 1e-8,
            "weight_decay": 0
        }
    },
    "activation_checkpointing": {
        "partition_activations": false,
        "contiguous_memory_optimization": false
    },
    "wall_clock_breakdown": false
}
#! /bin/bash
# Change for multinode config
CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=1
MP_SIZE=1
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
main_dir=$(dirname $script_dir)
source $main_dir/config/model_glm_roberta_large.sh
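# config/model_glm_roberta_large.sh is expected to define MODEL_ARGS, which is expanded into gpt_options below.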
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
HOST_FILE_PATH="hostfile_single"   # overrides the multi-node hostfile above for single-node runs
en_data="/dataset/fd5061f6/english_data/glue_data/SST-2/train.tsv"
eval_data="/dataset/fd5061f6/english_data/glue_data/SST-2/dev.tsv"
config_json="$script_dir/ds_config_ft.json"
gpt_options=" \
--experiment-name finetune-glm-sst2 \
--model-parallel-size ${MP_SIZE} \
--mode finetune \
--train-iters 6000 \
--resume-dataloader \
$MODEL_ARGS \
--train-data ${en_data} \
--valid-data ${eval_data} \
--distributed-backend nccl \
--lr-decay-style cosine \
--warmup .02 \
--checkpoint-activations \
--fp16 \
--save-interval 6000 \
--eval-interval 100 \
--save /root/checkpoints \
--split 1 \
--strict-eval \
--eval-batch-size 8
"
# --load /root/checkpoints/pretrain-bert-mid-std-fulltrain12-02-06-10
# \ --sandwich-ln
# --split 949,50,1 \
# --load /root/checkpoints/pretrain-bert-mid11-28-15-38 \
gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_glm_sst2.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x