diff --git a/inference_object.py b/inference_object.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..576635af756dbdf36d4a99ecd6edcda5343fda3a 100644
--- a/inference_object.py
+++ b/inference_object.py
@@ -0,0 +1,110 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_object.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+
+from arguments import get_args
+from model.cached_autoregressive_model import CachedAutoregressiveModel
+from model.cached_object_model import CachedObjectModel
+from model.ObjectModel import ObjectModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model
+    model = CachedObjectModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+
+    # define function for each query
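+    # token ids >= img_tokenizer.num_tokens are text/special tokens; marking that range
+    # invalid restricts sampling to image tokens only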
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, topk=args.top_k)
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split('\t')
+        print('raw text: ', raw_text)
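+        # query format: a caption (no spaces), then for each object four box coordinates
+        # followed by a class name, all space-separated, e.g.
+        # "一个人在骑马 10 20 300 400 人" (coordinate values here are only illustrative)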
+        parts = raw_text.split(' ')
+        objects = parts[1:]
+        seq = tokenizer.parse_query(f"[ROI1] {parts[0]}", img_size=args.img_size)
+        for i in range(len(objects) // 5):
+            seq.append(tokenizer['[POS0]'])
+            seq.extend(int(x) + args.old_token_num for x in objects[i*5:i*5+4])
+            seq.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
+        seq.extend(tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size))
+
+        if len(seq) > 1271:
+            raise ValueError('text too long.')
+        # calibrate text length
+        txt_len = seq.index(tokenizer['[BASE]'])
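+        # add a bias (in attention-logit space) so generated image positions attend more
+        # strongly to the text tokens; short prompts get a larger boost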
+        log_attention_weights = torch.zeros(len(seq), len(seq),
+                                            device=args.device, dtype=torch.half if args.fp16 else torch.float32)
+        log_attention_weights[txt_len + 2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4  # TODO args
+        # generation
+        seq = torch.cuda.LongTensor(seq, device=args.device)
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
+        for tim in range(max(args.batch_size // mbz, 1)):
+            output_list.append(
+                filling_sequence(model, seq.clone(),
+                                 batch_size=min(args.batch_size, mbz),
+                                 strategy=strategy,
+                                 log_attention_weights=log_attention_weights
+                                 )
+            )
+        output_tokens = torch.cat(output_list, dim=0)
+        # decoding
+        imgs, txts = [], []
+        for seq in output_tokens:
+            seq = seq.tolist()  # tensors have no .index(); work on a Python list
+            txt_len = seq.index(tokenizer['[BASE]'])
+            seq = seq[txt_len:]
+            _, decoded_imgs = tokenizer.DecodeIds(seq)
+            imgs.append(decoded_imgs[-1])  # only the last image (target)
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id)
+            os.makedirs(full_path, exist_ok=True)
+            save_multiple_images(imgs, full_path, False)
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.jpg', args.output_path)
+            save_multiple_images(imgs, full_path, True)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--full-query', action='store_true')
+    py_parser.add_argument('--img-size', type=int, default=256)
+    ObjectModel.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
\ No newline at end of file
diff --git a/model/ObjectModel.py b/model/ObjectModel.py
index c236f717da63aa2431c1b7b86e9c19d9355608f2..1000fe1f7594cc44c9cd65e9387e0f03d9b18c37 100644
--- a/model/ObjectModel.py
+++ b/model/ObjectModel.py
@@ -3,37 +3,55 @@ import torch.nn.functional as F
 
 
 from .base_model import BaseModel
-from .mixins import PositionEmbeddingMixin, AttentionMixin
+from .mixins import PositionEmbeddingMixin, AttentionMixin, WordEmebeddingMixin
 
 from mpu.transformer import split_tensor_along_last_dim
 from mpu.local_attention_function import f_similar, f_weighting
 from mpu.utils import sqrt
 from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
+from mpu.mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
 class ObjectModel(BaseModel):
     def __init__(self, args, transformer=None):
         super().__init__(args, transformer=transformer)
         additional_seqlen = args.new_sequence_length - args.max_sequence_length
         self.mixins.append(PositionEmbeddingMixin(
-            additional_seqlen, args.hidden_size
+            additional_seqlen, args.hidden_size,
+            reinit_slice=slice(-180, None)
+        ))
+        self.mixins.append(WordEmebeddingMixin(
+            args.old_token_num, args.additional_token_num, args.hidden_size
         ))
         self.layout = args.layout
 
     def position_embedding_forward(self, position_ids, *other_tensors):
-        position = position_ids[..., :self.layout[1]]
-        position_plus = position_ids[..., self.layout[1]:]
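+        # split positions by args.layout: [0, layout[0]) text, [layout[0], layout[1]) object
+        # boxes/names (new mixin embeddings), [layout[1], ...) image (reuses pretrained embeddings)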
+        position_text = position_ids[..., :self.layout[0]]
+        position_object = position_ids[..., self.layout[0]:self.layout[1]]
+        position_image = position_ids[..., self.layout[1]:]
         position_embeddings = torch.cat(
                 (
-                    self.transformer.position_embeddings(position),
-                    self.mixins[0].position_embeddings(position_plus)
+                    self.transformer.position_embeddings(position_text),
+                    self.mixins[0].position_embeddings(position_object),
+                    self.transformer.position_embeddings(position_image)
                 ),
                 dim=-2
             )
         return position_embeddings
 
+    def word_embedding_forward(self, input_ids, *other_tensors):
+        return self.mixins[1].word_embeddings(input_ids)
+
+    def final_forward(self, logits, *other_tensors):
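+        # project the final hidden states onto the extended vocabulary by tying the output
+        # weights to the mixin's enlarged word-embedding table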
+        logits = copy_to_model_parallel_region(logits)
+        logits = F.linear(logits, self.mixins[1].word_embeddings.weight)
+        return logits
+
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('ObjectModel', 'Object model configurations')
-        group.add_argument("--layout", type=str, default='64,1088')
-        group.add_argument("--new-sequence-length", type=int, default=5185)
+        group.add_argument("--layout", type=str, default='64,246,1270')
+        group.add_argument("--old-token-num", type=int, default=58219)
+        group.add_argument("--additional-token-num", type=int, default=257)
+        group.add_argument("--new-sequence-length", type=int, default=1271) #1089 + 180 + 2
         return parser
\ No newline at end of file
diff --git a/model/cached_object_model.py b/model/cached_object_model.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..6ef67d76ef63fbc905bc9b7155846066f1b9be99 100644
--- a/model/cached_object_model.py
+++ b/model/cached_object_model.py
@@ -0,0 +1,50 @@
+# -*- encoding: utf-8 -*-
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from .base_model import BaseModel
+from .ObjectModel import ObjectModel
+from mpu.transformer import standard_attention, split_tensor_along_last_dim
+
+
+class CachedObjectModel(ObjectModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.log_attention_weights = None
+
+    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+        attn_module = self.transformer.layers[layer_id].attention
+        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
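+        # `mem` caches this layer's key/value projections from previously generated tokens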
+
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+         mixed_key_layer,
+         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+        if mem is not None:  # the first time, mem is None
+            b = mixed_key_layer.shape[0]  # might change batch_size
+            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
+            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
+            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+
+        # same as training
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None,
+                                           log_attention_weights=self.log_attention_weights)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+
+        # new mem this layer
+        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
+
+        return output, new_mem
diff --git a/model/mixins.py b/model/mixins.py
index b8063fb04001dea35a281af72683ad2b41073d64..b5e12dcce595d17d8cdfee591b8aaab564ac717f 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -69,9 +69,16 @@ class AttentionMixin(BaseMixin):
             self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
 
 class WordEmebeddingMixin(BaseMixin):
-    def __init__(self, additional_token_num, hidden_size):
+    def __init__(self, old_token_num, additional_token_num, hidden_size,
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)):
         super(WordEmebeddingMixin, self).__init__()
-        self.word_embeddings = torch.nn.Embedding(additional_token_num, hidden_size)
+        self.reinit_slice = reinit_slice
+        self.word_embeddings = torch.nn.Embedding(old_token_num + additional_token_num, hidden_size)
         torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
+
     def reinit(self, transformer, *pre_mixins):
-        old_weights = transformer.
\ No newline at end of file
+        old_weights = transformer.word_embedding.weight.data
+        old_len, hidden_size = old_weights.shape
+        assert hidden_size == self.word_embeddings.weight.shape[-1]
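+        # copy the pretrained embedding rows into the front of the enlarged table;
+        # the newly added rows keep their random initialization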
+        self.word_embeddings.weight.data[:old_len].copy_(old_weights)
\ No newline at end of file
diff --git a/pretrain_objectModel.py b/pretrain_objectModel.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0428db3b72aa3cfa67e9a8f1b7d6771de5d8ea7a 100644
--- a/pretrain_objectModel.py
+++ b/pretrain_objectModel.py
@@ -0,0 +1,207 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   pretrain_objectModel.py
+@Time    :   2021/10/06 00:58:32
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+import numpy as np
+
+import mpu
+from arguments import get_args
+from model.ObjectModel import ObjectModel
+from training.deepspeed_training import training_main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
+
+def get_names():
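+    # Chinese names for the 91 COCO category ids (index = category id); index 0 is
+    # background and "空" fills the ids unused by the 80-class detection label set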
+    return ["背景",
+            "人","自行车","汽车","摩托车","飞机","公交车","火车","卡车","船","红绿灯",
+            "消防栓","空", "停车牌","停车收费表","长椅","鸟","猫","狗","马","羊",
+            "牛","大象","熊","斑马","长颈鹿","空","背包","伞","空","空",
+            "手提包","领带","手提箱","飞盘","滑雪板","滑雪板","球","风筝","棒球棒","棒球手套",
+            "滑板","冲浪板","网球拍","瓶子","空", "酒杯","杯子","叉子","刀子","勺子",
+            "碗","香蕉","苹果","三明治","橘子","花椰菜","胡萝卜","热狗","比萨饼","甜甜圈",
+            "蛋糕","椅子","沙发","盆栽植物","床","空", "餐桌","空","空","厕所",
+            "空","电视","笔记本","鼠标","遥控器","键盘","手机","微波炉","烤箱","烤面包机",
+            "水槽","冰箱","空","书","钟","花瓶","剪刀","泰迪熊","吹风机","牙刷"]
+
+
+def get_masks_and_position_ids(data,
+                               n_pads,
+                               object_pads,
+                               loss_mask=None,
+                               attention_mask=None, args=None):
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular).
+    if attention_mask is None:
+        assert loss_mask is not None
+        # loss_mask is zero exactly at the padded positions, so it can be reused as the
+        # row mask when building the causal attention mask
+        attention_mask = loss_mask[:, :seq_length].unsqueeze(-2).expand(batch_size, seq_length, seq_length).tril()
+        for i in range(batch_size):
+            attention_mask[i].fill_diagonal_(1)
+        attention_mask = attention_mask.unsqueeze(1)
+
+    # Loss mask.
+    if loss_mask is None:
+        loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
+
+    # Position ids.
+    # seq_length here is new_sequence_length - 1 = 1270 (tokens are shifted for next-token prediction)
+    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
+                                device=data.device)
+
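+    # per sample: text position ids restart at 0 after the left padding, object slots are
+    # numbered into the additional position-embedding table, and image positions continue
+    # counting from the last text position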
+    for i in range(batch_size):
+        torch.arange(64 - n_pads[i], out=position_ids[i, n_pads[i]:64],
+                     dtype=torch.long, device=data.device)
+        torch.arange(180 - object_pads[i], out=position_ids[i, 64 + object_pads[i]:64 + 180],
+                     dtype=torch.long, device=data.device)
+        torch.arange(64 - n_pads[i],  64 - n_pads[i] + seq_length - (64+180),
+                     out=position_ids[i, 64+180:],
+                     dtype=torch.long, device=data.device)
+    return attention_mask, loss_mask, position_ids
+
+
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['text', 'loss_mask', 'object_pad', 'n_pad']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens_ = data_b['text'].long()
+    loss_mask = data_b['loss_mask'].float()
+    n_pads = data_b['n_pad'].long()
+    object_pads = data_b['object_pad'].long()
+
+    labels = tokens_[:, 1:].contiguous()
+    loss_mask = loss_mask[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
+
+    attention_mask = None
+
+    # Get the masks and postition ids.
+    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+        tokens,
+        n_pads,
+        object_pads,
+        loss_mask=loss_mask,
+        attention_mask=attention_mask,
+        args=args
+    )
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+
+    return tokens, labels, loss_mask, attention_mask, position_ids
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+    # Forward model.
+    logits, *mems = model(tokens, position_ids, attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
+    # scaling loss mask
+    loss_mask = loss_mask.view(-1)
+
+    losses = losses.view(-1) * loss_mask
+    loss = torch.sum(losses) / loss_mask.sum()
+
+    return loss, {}
+
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    layout = [100,164,1188]
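+    # each binary sample packs 20 * 5 box fields (per object: 4 coordinates + category id),
+    # then 64 text tokens, then 1024 image tokens, i.e. offsets [100, 164, 1188]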
+    names = get_names()
+    tokens = []
+    for name in names:
+        tokens.append(tokenizer.EncodeAsIds(name))
+    # 4 + 4 + 1: each object takes 9 tokens, so 20 objects take 20 * 9 = 180 tokens
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[1]:layout[2]]]
+        text = row[layout[0]:layout[1]]
+        text = text[text > 0][:63]  # [ROI]
+        object_tokens = []
+        # print(row[:100])
+        for i in range(20):
+            object = row[i * 5: (i+1) * 5]
+            if object[0] == -1:
+                break
+            # print("object", object)
+            object[2] += object[0]
+            object[3] += object[1]
+            object_tokens.append(tokenizer['[POS0]'])
+            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])
+            object_tokens.extend(tokens[object[4]])
+        object_tokens = np.array(object_tokens)
+        # print(object_tokens)
+        object_pad = 180 - object_tokens.shape[-1]
+        object_tokens = np.concatenate([
+            np.array([tokenizer['[PAD]']] * object_pad, dtype=np.int64),
+            object_tokens
+        ], axis = 0)
+        # 180
+        text_object = np.concatenate([text, np.array(object_tokens, dtype=np.int64)], axis=0)
+        # print(len(text), len(text_object))
+        # print(len(codes[0]))
+        merged = TextCodeTemplate(text_object, codes[0], tokenizer)
+        # print(len(merged), len(text))
+        n_pad = args.new_sequence_length - len(merged)
+        parts = [
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+            merged
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret,
+                'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),
+                'object_pad':object_pad,
+                'n_pad':n_pad
+                }
+
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
+
+
+if __name__ == '__main__':
+    py_parser = argparse.ArgumentParser(add_help=False)
+
+    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
+
+    ObjectModel.add_model_specific_args(py_parser)
+
+    known, args_list = py_parser.parse_known_args()
+
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    args.layout = [int(x) for x in args.layout.split(',')]
+
+    training_main(args, model_cls=ObjectModel, forward_step_function=forward_step,
+                  create_dataset_function=create_dataset_function)
diff --git a/scripts/finetune_into_object.sh b/scripts/finetune_into_object.sh
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9cae1aae45e65ec3c86c63804c6e81a617c26ec1 100644
--- a/scripts/finetune_into_object.sh
+++ b/scripts/finetune_into_object.sh
@@ -0,0 +1,60 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=4
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+
+full_data='/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_coco_detection_task/coco/coco.bin.part_0.cogdata'
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name finetune2-object-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --train-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --mode finetune \
+       --resume-dataloader \
+       --load pretrained/cogview/cogview-base
+"
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_objectModel.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/scripts/kill_python.sh b/scripts/kill_python.sh
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..3900363636355cbdd970edaff9d622eef166337a 100644
--- a/scripts/kill_python.sh
+++ b/scripts/kill_python.sh
@@ -0,0 +1 @@
+pdsh -w ssh:node[1-3] "pkill -9 python"
\ No newline at end of file
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index 5a09f7712528974023124dd510d3762623e7c8d9..78e665fb2ca7c9b872e5708ae2af8b0ae2bf47b2 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -2,7 +2,7 @@
 
 # Change for multinode config
 
-NUM_WORKERS=1
+NUM_WORKERS=3
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -12,7 +12,6 @@ main_dir=$(dirname $script_dir)
 
 OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
 HOST_FILE_PATH="hostfile"
-HOST_FILE_PATH="hostfile_single"
 
 full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
 small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
diff --git a/scripts/text2image_object.sh b/scripts/text2image_object.sh
index 96e8a045bbb2f3d4f046b9949dc734b12ae0288e..4ab4b70b745a3c47dfdd03c1f9ef80645ef531b3 100755
--- a/scripts/text2image_object.sh
+++ b/scripts/text2image_object.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-08-21
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -15,9 +15,9 @@ TOPK=200
 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
 
-MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
+MASTER_PORT=${MASTER_PORT} python inference_object.py \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
@@ -32,10 +32,10 @@ MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
        --top_k $TOPK \
        --sandwich-ln \
        --input-source ./input.txt \
-       --output-path samples_text2image \
+       --output-path samples_text2image_object \
        --batch-size 4 \
        --max-inference-batch-size 8 \
-       --device 0 \
+       --device 7 \
        $@