diff --git a/generation/object_sampling.py b/generation/object_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7de80e79c184060420e8dfb001f6fb801611c04
--- /dev/null
+++ b/generation/object_sampling.py
@@ -0,0 +1,123 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   object_sampling.py
+@Time    :   2021/10/08 15:43:59
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+from .sampling_strategies import BaseStrategy
+
+
+
+
+
+def update_mems(hiddens, mems, max_memory_length):
+    '''
+        hiddens: list (num_layers) of [batch, query_length, 2d]
+        mems: None or [num_layers, batch, memory_length, 2d]
+    '''
+    if hiddens is None:
+        return None
+    hiddens = torch.stack(hiddens)
+    memory_length = mems.shape[2] if mems is not None else 0
+    query_length = hiddens.shape[2]
+    new_memory_length = min(max_memory_length, memory_length + query_length)
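+    # e.g. with memory_length=10, query_length=5 and a large max_memory_length, new_memory_length=15:
+    # keep all 10 cached slots and append the 5 new hidden states below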
+    with torch.no_grad():
+        if new_memory_length <= query_length:
+            return hiddens[:, :, -new_memory_length:]
+        else:
+            if mems.shape[1] < hiddens.shape[1]:
+                mems = mems.expand(-1, hiddens.shape[1], -1, -1)
+            return torch.cat(
+                (mems[:, :, -new_memory_length + query_length:], hiddens),
+                dim=2
+            )
+
+
+def filling_sequence_object(
+        model,
+        seq,
+        n_pads,
+        object_pads,
+        batch_size,
+        strategy=BaseStrategy(),
+        max_memory_length=100000,
+        log_attention_weights=None
+):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+    '''
+    assert len(seq.shape) == 1
+
+    # building the initial tokens, attention_mask, and position_ids
+    context_length = 0
+    while seq[context_length] >= 0:
+        context_length += 1  # [0, context_length-1] are given
+    assert context_length > 0
+    tokens = seq.unsqueeze(0)
+
+    # Unlike training, there is no pad in front here
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.zeros(1, len(seq), dtype=torch.long, device=tokens.device)
+
+    torch.arange(64 - n_pads, out=position_ids[0, n_pads:64],
+                 dtype=torch.long, device=tokens.device)  # text
+    torch.arange(180 - object_pads, out=position_ids[0, 64 + object_pads:64 + 180])  # object
+    # breakpoint()
+    torch.arange(64 - n_pads, 64 - n_pads + len(seq) - (64 + 180),
+                 out=position_ids[0, 64 + 180:],
+                 dtype=torch.long, device=tokens.device)  # image
+    # breakpoint()
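+    # Worked example (illustrative values, not from the original code): with n_pads=4 and object_pads=162,
+    #   text span  : [0]*4 then 0..59   (positions restart after the leading pads)
+    #   object span: [0]*162 then 0..17 (its own counter starting from 0)
+    #   image span : 60, 61, ...        (continues the text counter)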
+
+    tokens = tokens[..., :context_length]
+    attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16
+    # initialize generation
+    counter = context_length - 1  # Last fixed index is ``counter''
+    index = 0  # Next forward starting index, also the length of cache.
+    mems = None  # mems are first-class citizens here, but we don't assume what is memorized.
+
+    # step-by-step generation
+    while counter < len(seq) - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+
+        if seq[counter + 1] >= 0:  # provided
+            tokens = torch.cat(
+                (
+                    tokens,
+                    seq[counter + 1: counter + 2].expand(tokens.shape[0], 1)
+                ), dim=1
+            )
+            counter += 1
+            continue
+
+        # forward
+        if log_attention_weights is not None:
+            model.log_attention_weights = log_attention_weights[..., index: counter + 1, :counter + 1]  # TODO memlen
+        kw_tensors = {'mems': mems} if mems is not None else {}
+        logits, *mem_kv = model(
+            tokens[:, index:],
+            position_ids[..., index: counter + 1],
+            attention_mask[..., index: counter + 1, :counter + 1],  # TODO memlen
+        **kw_tensors  # when there are no mems yet, don't pass the kwarg at all
+        )
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        counter += 1
+        index = counter
+        # sampling
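+        # the last-position logits are broadcast to batch_size rows; the strategy (BaseStrategy here)
+        # is expected to sample one token per row and append it to tokens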
+        logits = logits[:, -1].expand(batch_size, -1)  # [batch size, vocab size]
+        tokens = tokens.expand(batch_size, -1)
+        tokens, mems = strategy.forward(logits, tokens, mems)
+
+    model.log_attention_weights = None
+    return tokens
\ No newline at end of file
diff --git a/inference_object.py b/inference_object.py
index 576635af756dbdf36d4a99ecd6edcda5343fda3a..46c623b90da924e9cb69d66abf0118817a96170e 100644
--- a/inference_object.py
+++ b/inference_object.py
@@ -22,6 +22,7 @@ from training import load_checkpoint, initialize_distributed, set_random_seed, p
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy
 from generation.autoregressive_sampling import filling_sequence
+from generation.object_sampling import filling_sequence_object
 from generation.utils import timed_name, save_multiple_images, generate_continually
 
 
@@ -48,20 +49,28 @@ def main(args):
         print('raw text: ', raw_text)
         raw_text = raw_text.split(' ')
         objects = raw_text[1:]
-        seq = tokenizer.parse_query(f"[ROI1] {raw_text[0]}", img_size=args.img_size)
+        seq1 = tokenizer.parse_query(f"[ROI1] {raw_text[0]}", img_size=args.img_size)
+        seq2 = []
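+        # every 5 raw-text fields describe one object (4 coordinates + a class name); each object is
+        # encoded as [POS0] + 4 coordinate tokens (shifted by old_token_num) + class-name tokens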
         for i in range(len(objects)//5):
-            seq.append(tokenizer['POS0'])
-            seq.extend(objects[i*5:i*5+4] + args.old_token_num)
-            seq.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
-        seq.extend(tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size))
-
+            seq2.append(tokenizer['[POS0]'])
+            seq2.extend([int(x)+args.old_token_num for x in objects[i*5:i*5+4]])
+            seq2.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
+        object_pads = 180 - len(seq2)
+        seq2 = [tokenizer['[PAD]']] * object_pads + seq2
+        seq3 = tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size)
+        seq = seq1 + seq2 + seq3
+        n_pads = args.new_sequence_length - 1 - len(seq)
+        seq = ([tokenizer['[PAD]']] * n_pads) + seq
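+        # final layout: [PAD]*n_pads + [ROI1] text + [PAD]*object_pads + object tokens + [BASE] [BOI1] [MASK]*1024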
+        print("len seq is ", len(seq))
         if len(seq) > 1271:
             raise ValueError('text too long.')
         # calibrate text length
-        txt_len = seq.index(tokenizer['[BASE]'])
+        front_len = seq.index(tokenizer['[BASE]'])
         log_attention_weights = torch.zeros(len(seq), len(seq),
                                             device=args.device, dtype=torch.half if args.fp16 else torch.float32)
-        log_attention_weights[txt_len + 2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4  # TODO args
+        # TODO text attention
+        # log_attention_weights[front_len + 2:, 1:front_len] = 1.8 if front_len <= 10 else 1.4  # TODO args
+
         # generation
         seq = torch.cuda.LongTensor(seq, device=args.device)
         mbz = args.max_inference_batch_size
@@ -69,18 +78,21 @@ def main(args):
         output_list = []
         for tim in range(max(args.batch_size // mbz, 1)):
             output_list.append(
-                filling_sequence(model, seq.clone(),
-                                 batch_size=min(args.batch_size, mbz),
-                                 strategy=strategy,
-                                 log_attention_weights=log_attention_weights
-                                 )
+                filling_sequence_object(model, seq.clone(),
+                                        n_pads=n_pads,
+                                        object_pads=object_pads,
+                                        batch_size=min(args.batch_size, mbz),
+                                        strategy=strategy,
+                                        log_attention_weights=log_attention_weights
+                                        )
             )
         output_tokens = torch.cat(output_list, dim=0)
         # decoding
         imgs, txts = [], []
         for seq in output_tokens:
-            txt_len = seq.index(tokenizer['[BASE]'])
-            seq = seq[txt_len:]
+            # txt_len = seq.index(tokenizer['[BASE]'])
+            # breakpoint()
+            seq = seq[-1025:]
             _, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
             imgs.append(decoded_imgs[-1])  # only the last image (target)
         # save
@@ -89,7 +101,7 @@ def main(args):
             os.makedirs(full_path, exist_ok=True)
             save_multiple_images(imgs, full_path, False)
         else:
-            prefix = raw_text.replace('/', '')[:20]
+            prefix = raw_text[0].replace('/', '')[:20]
             full_path = timed_name(prefix, '.jpg', args.output_path)
             save_multiple_images(imgs, full_path, True)
 
@@ -105,6 +117,6 @@ if __name__ == "__main__":
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
-
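+    # ObjectModel indexes layout positionally, so convert e.g. '64,246,1270' into [64, 246, 1270]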
+    args.layout = [int(x) for x in args.layout.split(',')]
     with torch.no_grad():
         main(args)
\ No newline at end of file
diff --git a/model/ObjectModel.py b/model/ObjectModel.py
index 1000fe1f7594cc44c9cd65e9387e0f03d9b18c37..0adf84a35218da0acca2bc4d5d45ed52b740f773 100644
--- a/model/ObjectModel.py
+++ b/model/ObjectModel.py
@@ -24,7 +24,7 @@ class ObjectModel(BaseModel):
         ))
         self.layout = args.layout
 
-    def position_embedding_forward(self, position_ids, *other_tensors):
+    def position_embedding_forward(self, position_ids, **kw_tensors):
         # breakpoint()
         position_text = position_ids[..., :self.layout[0]]
         position_object = position_ids[..., self.layout[0]:self.layout[1]]
@@ -39,10 +39,10 @@ class ObjectModel(BaseModel):
             )
         return position_embeddings
 
-    def word_embedding_forward(self, input_ids, *other_tensors):
+    def word_embedding_forward(self, input_ids, **kw_tensors):
         return self.mixins[1].word_embeddings(input_ids)
 
-    def final_forward(self, logits, *other_tensors):
+    def final_forward(self, logits, **kw_tensors):
         logits = copy_to_model_parallel_region(logits)
         logits = F.linear(logits, self.mixins[1].word_embeddings.weight)
         return logits
@@ -50,7 +50,7 @@ class ObjectModel(BaseModel):
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('ObjectModel', 'Object model configurations')
-        group.add_argument("--layout", type=str, default='64,246,1270')
+        group.add_argument("--layout", type=str, default='64,246,1270')  # 246 = 64 + 182
         group.add_argument("--old-token-num", type=int, default=58219)
         group.add_argument("--additional-token-num", type=int, default=257)
         group.add_argument("--new-sequence-length", type=int, default=1271) #1089 + 180 + 2
diff --git a/model/cached_object_model.py b/model/cached_object_model.py
index 6ef67d76ef63fbc905bc9b7155846066f1b9be99..dc3f797d1dbe65261cfd4bad69e25c62f574753c 100644
--- a/model/cached_object_model.py
+++ b/model/cached_object_model.py
@@ -17,9 +17,9 @@ class CachedObjectModel(ObjectModel):
         super().__init__(args, transformer=transformer)
         self.log_attention_weights = None
 
-    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, **kwargs):
         attn_module = self.transformer.layers[layer_id].attention
-        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
+        mem = mems[layer_id] if mems is not None else None
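+        # mems[layer_id], when provided, is this layer's cached key/value projection from earlier steps (see new_mem below)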
 
         mixed_raw_layer = attn_module.query_key_value(hidden_states)
         (mixed_query_layer,
@@ -48,3 +48,4 @@ class CachedObjectModel(ObjectModel):
         new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
 
         return output, new_mem
+
diff --git a/pretrain_objectModel.py b/pretrain_objectModel.py
index 0428db3b72aa3cfa67e9a8f1b7d6771de5d8ea7a..c1ecd96a73c0ffc3039d674a2fa5bcf4db3f3ea5 100644
--- a/pretrain_objectModel.py
+++ b/pretrain_objectModel.py
@@ -62,14 +62,15 @@ def get_masks_and_position_ids(data,
     position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
                                 device=data.device)
 
+    # Text and image spans keep the original position ids; the object span uses its own position ids starting from zero
     for i in range(batch_size):
         torch.arange(64 - n_pads[i], out=position_ids[i, n_pads[i]:64],
-                     dtype=torch.long, device=data.device)
-        torch.arange(180-object_pads[i], out=position_ids[i, 64+object_pads:64+180])
+                     dtype=torch.long, device=data.device)  # text
+        torch.arange(180-object_pads[i], out=position_ids[i, 64+object_pads[i]:64+180])  # object
         # breakpoint()
         torch.arange(64 - n_pads[i],  64 - n_pads[i] + seq_length - (64+180),
                      out=position_ids[i, 64+180:],
-                     dtype=torch.long, device=data.device)
+                     dtype=torch.long, device=data.device)  # image
     return attention_mask, loss_mask, position_ids
 
 
@@ -97,7 +98,6 @@ def get_batch(data_iterator, args, timers):
     labels = tokens_[:, 1:].contiguous()
     loss_mask = loss_mask[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()
-
     attention_mask = None
 
     # Get the masks and postition ids.
@@ -112,7 +112,6 @@ def get_batch(data_iterator, args, timers):
     # Convert
     if args.fp16:
         attention_mask = attention_mask.half()
-
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
@@ -137,50 +136,51 @@ def forward_step(data_iterator, model, args, timers):
 
 def create_dataset_function(path, args):
     tokenizer = get_tokenizer()
-    layout = [100,164,1188]
+    layout = [100, 164, 1188]  # layout inside the .bin data file
     names = get_names()
     tokens = []
     for name in names:
         tokens.append(tokenizer.EncodeAsIds(name))
-    # 4 + 4 + 1 一个object需要9个token 20个需要 20*9 = 180个
+
+    # 4 coordinate tokens + 4 class-name tokens + 1 [POS0]: each object takes 9 tokens; at most 20 objects, so 20*9 = 180 tokens
+    # Layout:
+    # first 64      : [PAD] * n + [ROI] + text
+    # 64 ~ 64+180   : [PAD] * n + ([POS0] + xmin + ymin + xmax + ymax + class) * m
+    # 64+180 ~ end  : [BASE] [BOI] imageTokens [EOI]
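+    # e.g. a row with 2 objects whose class names take 4 tokens each yields 2*9 = 18 object tokens and 162 leading [PAD]s in that span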
     def process_fn(row):
         row = row.astype(np.int64)
         codes = [row[layout[1]:layout[2]]]
         text = row[layout[0]:layout[1]]
         text = text[text > 0][:63]  # [ROI]
         object_tokens = []
-        # print(row[:100])
-        for i in range(20):
+
+        for i in range(20):  # at most 20 objects
             object = row[i * 5: (i+1) * 5]
             if object[0] == -1:
                 break
-            # print("object", object)
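+            # convert the stored width/height into absolute xmax/ymax by adding the box origin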
             object[2] += object[0]
             object[3] += object[1]
             object_tokens.append(tokenizer['[POS0]'])
-            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])
+            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])  # a coordinate token id = old_token_num (original vocab size) + the coordinate value
             object_tokens.extend(tokens[object[4]])
+
         object_tokens = np.array(object_tokens)
-        # print(object_tokens)
-        object_pad = 180 - object_tokens.shape[-1]
+        object_pad = 180 - object_tokens.shape[-1]  # pad the object span up to 180 tokens
         object_tokens = np.concatenate([
             np.array([tokenizer['[PAD]']] * object_pad, dtype=np.int64),
             object_tokens
         ], axis = 0)
-        # 180
         text_object = np.concatenate([text, np.array(object_tokens, dtype=np.int64)], axis=0)
-        # print(len(text), len(text_object))
-        # print(len(codes[0]))
         merged = TextCodeTemplate(text_object, codes[0], tokenizer)
-        # print(len(merged), len(text))
-        n_pad = args.new_sequence_length - len(merged)
+
+        n_pad = args.new_sequence_length - len(merged)  # pad the whole sequence to the maximum length
         parts = [
             np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
             merged
         ]
         ret = np.concatenate(parts, axis=0)
         return {'text': ret,
-                'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),
+                'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),  # loss mask is 0 over both [PAD] spans
                 'object_pad':object_pad,
                 'n_pad':n_pad
                 }
diff --git a/scripts/text2image_object.sh b/scripts/text2image_object.sh
index 296b9daab9a19eddc1a35c68bcab9fa0141abe2b..2b5429962aa7648714d376e659171fe32f168dec 100755
--- a/scripts/text2image_object.sh
+++ b/scripts/text2image_object.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-12-19
+CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-12-29
 NLAYERS=48
 NHIDDEN=2560
 NATT=40