diff --git a/SwissArmyTransformer/model/finetune/__init__.py b/SwissArmyTransformer/model/finetune/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2cc6a0915d3df1762edd4f202b2f96b81a5609
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/__init__.py
@@ -0,0 +1,2 @@
+from .mlp_head import MLPHeadMixin
+from .prompt_tuning import PrefixTuningMixin, PTuningV2Mixin
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/finetune/mlp_head.py b/SwissArmyTransformer/model/finetune/mlp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25d661d0d36a24887614dbcbb07faaa445956e2
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/mlp_head.py
@@ -0,0 +1,36 @@
+
+# -*- encoding: utf-8 -*-
+'''
+@File    :   mlp_head.py
+@Time    :   2021/12/12 20:44:09
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+
+class MLPHeadMixin(BaseMixin):
+    def __init__(self, hidden_size, *output_sizes, bias=True, activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005):
+        super().__init__()
+        self.activation_func = activation_func
+        last_size = hidden_size
+        self.layers = torch.nn.ModuleList()
+        for sz in output_sizes:
+            this_layer = torch.nn.Linear(last_size, sz, bias=bias)
+            last_size = sz
+            torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
+            self.layers.append(this_layer)
+
+    def final_forward(self, logits, **kw_args):
+        for i, layer in enumerate(self.layers):
+            if i > 0:
+                logits = self.activation_func(logits)
+            logits = layer(logits)
+        return logits
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/finetune/prompt_tuning.py b/SwissArmyTransformer/model/finetune/prompt_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..51cd1326b2a42cbcb89a327ab3b15892e3c0a2c1
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/prompt_tuning.py
@@ -0,0 +1,45 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   prompt_tuning.py
+@Time    :   2021/12/12 20:45:18
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+
+
+class PrefixTuningMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len):
+        super().__init__()
+        self.prefix = torch.nn.ParameterList([
+            torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head)*0.01)
+            for layer_id in range(num_layers)
+        ])
+        self.prefix_len = prefix_len
+
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
+        prefix_k, prefix_v = self.prefix[kw_args['layer_id']]
+
+        b, nh, seq_len, hidden_size = k.shape
+        prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size)
+        prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size)
+
+        k = torch.cat((k, prefix_k), dim=2)
+        v = torch.cat((v, prefix_v), dim=2)
+        if mask.numel() > 1:
+            mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype)
+            mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1)
+            mask = torch.cat((mask, mask_prefixed), dim=-1)
+        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+
+PTuningV2Mixin = PrefixTuningMixin
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/mixins.py b/SwissArmyTransformer/model/mixins.py
index 2c6b09953615c5c9dbb0bb7b020a585bcccf6324..d1fcf89f7311f2f52835e54285dadfd016714fde 100644
--- a/SwissArmyTransformer/model/mixins.py
+++ b/SwissArmyTransformer/model/mixins.py
@@ -17,7 +17,7 @@ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
 from SwissArmyTransformer.mpu.transformer import unscaled_init_method
 from .base_model import BaseMixin
 from .cached_autoregressive_model import CachedAutoregressiveMixin
-
+from .finetune import *
 
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size,
diff --git a/examples/glm/finetune_glm_sst2.py b/examples/glm/finetune_glm_sst2.py
new file mode 100755
index 0000000000000000000000000000000000000000..14ef44d9cfa606d78fa9ea461f609453b24d8e73
--- /dev/null
+++ b/examples/glm/finetune_glm_sst2.py
@@ -0,0 +1,109 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   finetune_glm_sst2.py
+@Time    :   2021/12/12 20:53:28
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+from SwissArmyTransformer.data_utils.datasets import TSVDataset
+import torch
+import argparse
+import numpy as np
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+from SwissArmyTransformer.training.deepspeed_training import training_main
+from SwissArmyTransformer.data_utils import TSVDataset
+from SwissArmyTransformer.model import GLMModel
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.model.mixins import MLPHeadMixin, PrefixTuningMixin
+
+class ClassificationModel(GLMModel):
+    def __init__(self, args, transformer=None, parallel_output=True):
+        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
+        self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
+        self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
+    def disable_untrainable_params(self):
+        self.transformer.word_embeddings.requires_grad_(False)
+        # for layer_id in range(len(self.transformer.layers)):
+        #     self.transformer.layers[layer_id].requires_grad_(False)
+    
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['sentence', 'label']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['sentence'].long()
+    labels = data_b['label'].long()
+    batch_size, seq_length = tokens.size()
+    
+    position_ids = torch.zeros(2, seq_length, device=tokens.device, dtype=torch.long)
+    torch.arange(0, seq_length, out=position_ids[0, :seq_length])
+    position_ids = position_ids.unsqueeze(0)
+    
+    attention_mask = torch.ones((batch_size, 1, seq_length, seq_length), device=tokens.device)
+
+    attention_mask[...,:seq_length] -= (tokens==-1).view(batch_size, 1, 1, seq_length).float()
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    return tokens, labels, attention_mask, position_ids, (tokens!=-1)
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, labels, attention_mask, position_ids, loss_mask = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+
+    logits, *mems = model(tokens, position_ids, attention_mask)
+    pred = ((logits.contiguous().float().squeeze(-1)) * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
+    loss = torch.nn.functional.binary_cross_entropy_with_logits(
+        pred, 
+        labels.float()
+        )
+    acc = ((pred > 0.).long() == labels).sum() / labels.numel()
+    return loss, {'acc': acc}
+
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    def process_fn(row):
+        sentence, label = tokenizer._encode(row[0]), int(row[1])
+        sentence = [tokenizer.get_command('ENC').Id] + sentence + [tokenizer.get_command('eos').Id]
+        if len(sentence) >= args.sample_length:
+            sentence = sentence[:args.sample_length]
+        else:
+            sentence.extend([-1] * (args.sample_length-len(sentence)))
+        return {'sentence': np.array(sentence, dtype=np.int64), 'label': label}
+    return TSVDataset(path, process_fn, with_heads=True)
+
+if __name__ == '__main__':    
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--new_hyperparam', type=str, default=None)
+    py_parser.add_argument('--sample_length', type=int, default=80)
+    py_parser.add_argument('--prefix_len', type=int, default=16)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    # from cogdata.utils.ice_tokenizer import get_tokenizer as get_ice
+    # tokenizer = get_tokenizer(args=args, outer_tokenizer=get_ice())
+    training_main(args, model_cls=ClassificationModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/examples/glm/scripts/ds_config_ft.json b/examples/glm/scripts/ds_config_ft.json
new file mode 100755
index 0000000000000000000000000000000000000000..17effe048977096633ce31ef75a1cd0d9b2d813b
--- /dev/null
+++ b/examples/glm/scripts/ds_config_ft.json
@@ -0,0 +1,30 @@
+{
+  "train_micro_batch_size_per_gpu":64,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 10,
+  "gradient_clipping": 0.1,
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 400,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.00001,
+      "betas": [
+        0.9,
+        0.95
+      ],
+      "eps": 1e-8,
+      "weight_decay": 0
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false
+  },
+  "wall_clock_breakdown": false
+}
diff --git a/examples/glm/scripts/finetune_sst2.sh b/examples/glm/scripts/finetune_sst2.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4e4809a07557b96afa37b00e1ea74a5830d8608e
--- /dev/null
+++ b/examples/glm/scripts/finetune_sst2.sh
@@ -0,0 +1,62 @@
+#! /bin/bash
+
+# Change for multinode config
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=1
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+source $main_dir/config/model_glm_roberta_large.sh
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+en_data="/dataset/fd5061f6/english_data/glue_data/SST-2/train.tsv"
+eval_data="/dataset/fd5061f6/english_data/glue_data/SST-2/dev.tsv"
+
+
+config_json="$script_dir/ds_config_ft.json"
+gpt_options=" \
+       --experiment-name finetune-glm-sst2 \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       --train-iters 6000 \
+       --resume-dataloader \
+       $MODEL_ARGS \
+       --train-data ${en_data} \
+       --valid-data ${eval_data} \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .02 \
+       --checkpoint-activations \
+       --fp16 \
+       --save-interval 6000 \
+       --eval-interval 100 \
+       --save /root/checkpoints \
+       --split 1 \
+       --strict-eval \
+       --eval-batch-size 8 
+"
+       # --load  /root/checkpoints/pretrain-bert-mid-std-fulltrain12-02-06-10
+       #  \       --sandwich-ln
+       # --split 949,50,1 \
+       # --load /root/checkpoints/pretrain-bert-mid11-28-15-38 \
+
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_glm_sst2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x