diff --git a/finetune_retrieval.py b/finetune_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1282a4d262d42def7b14153d3378bd104b8067c1
--- /dev/null
+++ b/finetune_retrieval.py
@@ -0,0 +1,233 @@
+# Fine-tune a pretrained CogView transformer into a text-image retrieval model.
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+import argparse
+import numpy as np
+
+import mpu
+from arguments import get_args
+from model.retrieval_model import RetrievalModel
+from training.deepspeed_training import training_main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
+
+def get_masks_and_position_ids(data,
+                               n_pads,
+                               attention_mask=None, args=None):
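+    """Build the block-causal attention mask and position ids.
+
+    Each sample is laid out as
+        [PAD]*n_pad | text | image code | [POS0] [POS1]
+    with the text segment occupying positions [:layout[0] - 5]. The mask is
+    causal within each segment, blocks cross-attention between text and image,
+    masks the pads entirely, and restricts [POS1] to the text segment (plus
+    itself) and [POS0] to the image segment (plus itself). Position ids restart
+    at 0 after the pads; the two retrieval tokens index a separate 2-entry
+    position table (see RetrievalModel.position_embedding_forward).
+    """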
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+    
+    layout = args.layout
+    
+    # Attention mask
+    if attention_mask is None:
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
+        attention_mask.tril_()
+        
+        attention_mask[:, :args.layout[0] - 5, args.layout[0] - 5:] = 0 # text tokens do not attend to image tokens
+        attention_mask[:, args.layout[0] - 5:-2, :args.layout[0] - 5] = 0 # image tokens do not attend to text tokens
+
+        attention_mask[:, -1, args.layout[0] - 5:-1] = 0 # [POS1] (text retrieval) attends only to the text segment and itself
+        attention_mask[:, -2, :args.layout[0] - 5] = 0 # [POS0] (image retrieval) attends only to the image segment and itself
+        
+        for i in range(batch_size): # 0 attention for padding
+            attention_mask[i, :n_pads[i], :] = 0
+            attention_mask[i, :, :n_pads[i]] = 0
+        
+        attention_mask.unsqueeze_(1)
+    
+    # Position ids.
+    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
+                               device=data.device)
+    for i in range(batch_size):
+        torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]],
+            dtype=torch.long, device=data.device)
+    position_ids[:, -1] = 0 # [POS1] indexes entry 0 of the retrieval position table
+    position_ids[:, -2] = 1 # [POS0] indexes entry 1
+
+    return attention_mask, position_ids
+
+
+def get_batch(data_iterator, args, timers):
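+    """Fetch one batch, broadcast it across model-parallel ranks, and build
+    the attention mask and position ids."""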
+    # Items and their type.
+    keys = ['text', 'n_pads']
+    datatype = torch.int64
+    
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+    
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    n_pads = data_b['n_pads'].long()
+    
+    attention_mask = None
+    
+    # Get the masks and position ids.
+    attention_mask, position_ids = get_masks_and_position_ids(
+        tokens,
+        n_pads,
+        attention_mask=attention_mask,
+        args=args
+        )
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+        
+    return tokens, attention_mask, position_ids, n_pads
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
+        numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+def parallel_contrastive_loss(lvec, rvec, args):
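+    """Symmetric (CLIP-style) contrastive loss across all data-parallel ranks.
+
+    The left/right vectors from every rank are all-gathered (detached) to
+    build the full batch_size x batch_size logit matrix; the locally computed
+    logits are then written back in so gradients flow through this rank's
+    rows and columns. Returns the mean of the two directional losses along
+    with each directional loss.
+    """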
+    device = args.device
+    temp = args.retrieval_temp
+    lvec = lvec * np.exp(temp) # fold the (fixed) temperature into one side of the dot product
+    
+    rank = mpu.get_data_parallel_rank()
+    world_size = mpu.get_data_parallel_world_size()
+    
+    batch_size_per_partition, hidden_size = lvec.shape
+    batch_size = batch_size_per_partition * world_size
+    split_start, split_end = rank * batch_size_per_partition, (rank + 1) * batch_size_per_partition
+    arange_1d = torch.arange(0, batch_size, device=lvec.device)
+    
+    # Gather the detached vectors from every data-parallel rank (all_gather does not backpropagate).
+    broadcast_lvec, broadcast_rvec = lvec.detach(), rvec.detach()
+    parallel_lvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half() for _ in range(world_size)]
+    parallel_rvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half() for _ in range(world_size)]
+    
+    torch.distributed.all_gather(parallel_lvec, broadcast_lvec, group=mpu.get_data_parallel_group())
+    torch.distributed.all_gather(parallel_rvec, broadcast_rvec, group=mpu.get_data_parallel_group())
+    
+    parallel_lvec = torch.cat(parallel_lvec, dim=0)
+    parallel_rvec = torch.cat(parallel_rvec, dim=0)
+    
+    # Calculate logits
+    local_local_logits = lvec @ rvec.permute(1, 0)
+    local_dist_logits = lvec @ parallel_rvec.permute(1, 0)
+    dist_local_logits = parallel_lvec @ rvec.permute(1, 0)
+    
+    # Gather this rank's text-to-image logits from every rank to form the full matrix.
+    broadcast_local_dist_logits = local_dist_logits.detach()
+    parallel_logits = [torch.zeros(batch_size_per_partition, batch_size, device=torch.device(device)).half() for _ in range(world_size)]
+    torch.distributed.all_gather(parallel_logits, broadcast_local_dist_logits, group=mpu.get_data_parallel_group())
+    parallel_logits = torch.cat(parallel_logits, dim=0)
+    
+    # Re-insert the locally computed logits: the gathered copies are detached, so only these rows/columns carry gradient.
+    parallel_logits[split_start:split_end, :] = local_dist_logits
+    parallel_logits[:, split_start:split_end] = dist_local_logits
+    parallel_logits[split_start:split_end, split_start:split_end] = local_local_logits    
+    
+    predicted_logits = parallel_logits[arange_1d, arange_1d]
+    
+    # Calculate left2right loss: log(sum(exp(logits))) - predicted logit,
+    # computed with the usual max-subtraction trick for numerical stability.
+    left_logits_max = torch.max(parallel_logits, dim=1)[0]
+    left_logits = parallel_logits.sub(left_logits_max.unsqueeze(dim=1))
+    left_exp_logits = left_logits.exp()
+    left_sum_exp_logits = left_exp_logits.sum(dim=1)
+    left_loss = torch.log(left_sum_exp_logits) + left_logits_max - predicted_logits # add the max back
+    left_loss = left_loss.sum()
+    
+    # Calculate right2left loss over the transposed logit matrix.
+    parallel_logits_t = parallel_logits.permute(1, 0)
+    right_logits_max = torch.max(parallel_logits_t, dim=1)[0]
+    right_logits = parallel_logits_t.sub(right_logits_max.unsqueeze(dim=1))
+    right_exp_logits = right_logits.exp()
+    right_sum_exp_logits = right_exp_logits.sum(dim=1)
+    right_loss = torch.log(right_sum_exp_logits) + right_logits_max - predicted_logits # add the max back
+    right_loss = right_loss.sum()
+    
+    total_loss = (left_loss + right_loss) / 2
+    
+    return total_loss, left_loss, right_loss
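+
+# For intuition, a single-process equivalent of parallel_contrastive_loss
+# (a minimal sketch for sanity checks, not used by the training loop):
+#
+#     logits = lvec @ rvec.t()  # lvec is already scaled by exp(temp)
+#     labels = torch.arange(len(lvec), device=lvec.device)
+#     loss = (F.cross_entropy(logits, labels, reduction='sum')
+#             + F.cross_entropy(logits.t(), labels, reduction='sum')) / 2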
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+    
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, attention_mask, position_ids, n_pads = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+    
+    # Forward model.
+    (txt_vecs, img_vecs), *mems = model(tokens, position_ids, attention_mask)
+    
+    # L2-normalize the retrieval embeddings
+    txt_vecs = F.normalize(txt_vecs, dim=1)
+    img_vecs = F.normalize(img_vecs, dim=1)
+    
+    loss, txt2img_loss, img2txt_loss = parallel_contrastive_loss(txt_vecs, img_vecs, args)
+    
+    return loss, {'txt2img_loss': txt2img_loss, 'img2txt_loss': img2txt_loss}
+    
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer(args)
+    layout = args.layout # e.g. [64, 1088]: text slots end at 64, image code at 1088
+    def process_fn(row):
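+        # Each binary row holds layout[0] text-token slots followed by the image
+        # code; repack it as [PAD]*n_pad | template(text, code) | [POS0] [POS1]
+        # so every sample has a fixed length with the text left-padded.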
+        row = row.astype(np.int64)
+        codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+        
+        text = row[:layout[0]]
+        text = text[text > 0][:layout[0] - 6]
+        n_pad = layout[0] - 6 - len(text)
+        merged = TextCodeTemplate(text, codes[0], tokenizer)
+        parts = [
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+            merged, 
+            np.array([tokenizer['[POS0]'], tokenizer['[POS1]']], dtype=np.int64)
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret,
+                'n_pads': n_pad}
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+    
+
+if __name__ == '__main__':
+    py_parser = argparse.ArgumentParser(add_help=False)
+    
+    RetrievalModel.add_model_specific_args(py_parser)
+    
+    known, args_list = py_parser.parse_known_args()
+    
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    args.layout = [int(x) for x in args.layout.split(',')]
+    
+    training_main(args, model_cls=RetrievalModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
\ No newline at end of file
diff --git a/model/mixins.py b/model/mixins.py
index 78befb8b27637deb7bf7c98bdcfde408eb28b483..618e7910d924707a1967008350a3c2821ea3e6d0 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -67,3 +67,40 @@ class AttentionMixin(BaseMixin):
             self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
             self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
             self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
+
+class ParallelLinearMixin(BaseMixin):
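+    """A model-parallel linear projection head (gather_output=True, so every
+    rank receives the full output)."""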
+    def __init__(self, input_size, output_size, bias=True,
+                 init_method=torch.nn.init.xavier_normal_, stride=1,
+                 keep_master_weight_for_test=False):
+        super(ParallelLinearMixin, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        
+        self.parallel_linear = ColumnParallelLinear(
+            input_size, output_size, bias=bias, gather_output=True,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+
+    def forward(self, input_):
+        return self.parallel_linear(input_)
+
+class ParallelDoubleLayerLinearMixin(BaseMixin):
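+    """A two-layer model-parallel MLP head: a column-parallel layer with
+    gather_output=False feeding a row-parallel layer with input_is_parallel=True,
+    the standard Megatron split that avoids an intermediate all-gather."""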
+    def __init__(self, input_size, hidden_size, output_size, bias=True,
+                 init_method=torch.nn.init.xavier_normal_, stride=1,
+                 keep_master_weight_for_test=False):
+        super(ParallelDoubleLayerLinearMixin, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        
+        self.column_parallel_linear = ColumnParallelLinear(
+            input_size, hidden_size, bias=bias, gather_output=False,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+        self.row_parallel_linear = RowParallelLinear(
+            hidden_size, output_size, bias=bias, input_is_parallel=True,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+    
+    def forward(self, input_):
+        return self.row_parallel_linear(self.column_parallel_linear(input_))
\ No newline at end of file
diff --git a/model/retrieval_model.py b/model/retrieval_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..96fea46e790cb3b43ea48308cb2b26bbce4a0fdc
--- /dev/null
+++ b/model/retrieval_model.py
@@ -0,0 +1,65 @@
+# -*- encoding: utf-8 -*-
+# Retrieval heads on top of the CogView base transformer.
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .mixins import PositionEmbeddingMixin, AttentionMixin, ParallelLinearMixin
+
+from mpu.transformer import split_tensor_along_last_dim
+from mpu.utils import sqrt
+from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
+
+
+class RetrievalModel(BaseModel):
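+    """CogView transformer with two contrastive retrieval heads.
+
+    The last two tokens of every sequence ([POS0], [POS1]) summarize the image
+    and text segments respectively; final_forward projects their hidden states
+    into a shared retrieval space through separate linear heads.
+    """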
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.layout = args.layout
+        self.txt_img_split = args.txt_img_split
+        
+        self.mixins.append(PositionEmbeddingMixin(
+            2, args.hidden_size
+        ))
+        self.mixins.extend([
+            ParallelLinearMixin(
+                args.hidden_size, args.retrieval_size),
+            ParallelLinearMixin(
+                args.hidden_size, args.retrieval_size)
+            ])
+    
+    def reinit(self):
+        pass
+    
+    def position_embedding_forward(self, position_ids, *other_tensors):
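+        # Use the pretrained position table for the sequence proper and a newly
+        # trained 2-entry table (mixin 0) for the two retrieval tokens.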
+        position_embeddings = torch.cat(
+                (
+                    self.transformer.position_embeddings(position_ids[:, :-2]),
+                    self.mixins[0].position_embeddings(position_ids[:, -2:])
+                ),
+                dim=-2
+            )
+        return position_embeddings
+    
+    def final_forward(self, logits, *other_tensors):
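+        # Index -1 holds [POS1] (text summary) and index -2 holds [POS0] (image
+        # summary); each is projected by its own retrieval head.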
+        txt_logits = logits[:, -1, :]
+        img_logits = logits[:, -2, :]
+        return (self.mixins[1](txt_logits), self.mixins[2](img_logits))
+    
+    def disable_untrainable_params(self):
+        pass
+    
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        group = parser.add_argument_group('RetrievalModel', 'retrieval model configurations')
+        group.add_argument('--txt-img-split', action='store_true')
+        group.add_argument('--retrieval-temp', type=float, default=0.)
+        group.add_argument('--retrieval-mode', type=str, default='txt2img',
+                            choices=['txt2img', 'img2txt', 'symmetric'])
+        group.add_argument('--retrieval-hidden-size', type=int, default=2048)
+        group.add_argument('--retrieval-size', type=int, default=1024)
+        group.add_argument("--layout", type=str, default='64,1088')
+        return parser
\ No newline at end of file
diff --git a/scripts/finetune_into_retrieval.sh b/scripts/finetune_into_retrieval.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c0a38d2af5091cede59757fe0f47e1103c77a389
--- /dev/null
+++ b/scripts/finetune_into_retrieval.sh
@@ -0,0 +1,59 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+# full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_1/ali/ali.bin.part_0.cogdata"
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name finetune-retrieval-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       --batch-size 4 \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --resume-dataloader \
+       --train-data ${small_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --load /workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_retrieval.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 90190a6f95eb775808cc058b9f912d3a9b6555d2..2d74dbf8fa7ffd278d109eba03853cd3ad895f6a 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -335,7 +335,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
 
         # Check nan or inf in forward, preventing it from interfering loss scaler,
         # and all reduce metrics by the way
-        loss_checker = lm_loss.detach()
+        loss_checker = lm_loss.clone().detach()
         for name in metrics:
             metrics[name] = metrics[name].detach()
             torch.distributed.all_reduce(metrics[name].data)