From 7b90b24c2abafcf3b4195b6d3164be2b21d64fc4 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Tue, 10 Aug 2021 14:16:47 +0000
Subject: [PATCH] add CompactBinaryDataset and a bird-and-animal example dataset

---
 arguments.py                    |  4 +++-
 data_utils/configure_data.py    |  4 ++--
 data_utils/datasets.py          | 33 ++++++++++++++++++++++++++++++++-
 data_utils/templates.py         | 11 ++++++++++-
 readme.md                       |  7 +++++--
 scripts/pretrain_single_node.sh |  6 +++---
 scripts/text2image.sh           |  2 +-
 7 files changed, 56 insertions(+), 11 deletions(-)

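Notes (informational; this region is ignored by `git apply`): the new
CompactBinaryDataset reads a headerless binary file of fixed-size int32 rows,
where each row holds 64 text-token ids (padded with -1, which process_fn
strips via text[text > -1]) followed by 1024 image-code ids. Below is a
minimal sketch of a compatible packer; the -1 pad value is inferred from the
reading code, and the helper name pack_bin is illustrative, not part of this
patch.

import numpy as np

TEXT_LEN, CODE_LEN = 64, 1024  # must match length_per_sample=64+1024

def pack_bin(samples, out_path):
    """samples: iterable of (text_ids, image_code_ids) integer sequences."""
    rows = []
    for text_ids, image_codes in samples:
        assert len(text_ids) <= TEXT_LEN and len(image_codes) == CODE_LEN
        text = np.full(TEXT_LEN, -1, dtype=np.int32)  # -1 padding, dropped on read
        text[:len(text_ids)] = text_ids
        rows.append(np.concatenate([text, np.asarray(image_codes, dtype=np.int32)]))
    np.stack(rows).tofile(out_path)  # raw int32, no header: what fromfile/memmap expect
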
diff --git a/arguments.py b/arguments.py
index 3407a0f..9825dfb 100755
--- a/arguments.py
+++ b/arguments.py
@@ -260,7 +260,9 @@ def add_data_args(parser):
     group.add_argument('--dataset-type', type=str,
                        default='TokenizedDataset',
                        choices=['TokenizedDataset',
-                                'TextCodeDataset'],
+                                'TextCodeDataset',
+                                'CompactBinaryDataset'
+                                ],
                        help='what type of dataset to use')
 
     group.add_argument('--max-memory-length', type=int, default=2048,
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index ee4860c..4e4cb0b 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -282,7 +282,7 @@ class RandomMappingDataset(data.Dataset):
         self.wrapped_data = ds
 
     def __len__(self):
-        return len(self.wrapped_data) * 60
+        return len(self.wrapped_data) * 200  # virtually repeat the (small) dataset per epoch
 
     def __getitem__(self, index):
         rng = random.Random(index)
@@ -301,7 +301,7 @@ def detect_new_datasets(args):
     found = []
     for _p in os.listdir(args.new_dataset_path):
         p = os.path.join(args.new_dataset_path, _p)
-        if str(p).endswith('lmdb') and not str(os.path.abspath(p)) in current_datasets:
+        if str(p).endswith(('lmdb', 'bin')) and str(os.path.abspath(p)) not in current_datasets:
             found.append(p)
     if len(found) == 0:
         return None
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index 470ac6c..5c1c183 100755
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -59,7 +59,26 @@ class LMDBDataset(Dataset):
             row = pickle.loads(txn.get(key))
 
             return self.process_fn(row)
-        
+
+class BinaryDataset(Dataset):
+    def __init__(self, path, process_fn, length_per_sample=64+1024, dtype='int32', preload=False, **kwargs):
+        assert length_per_sample is not None
+        self.length_per_sample = length_per_sample
+        self.dtype = np.dtype(dtype)
+        self.process_fn = process_fn
+        if preload:
+            self.bin = np.fromfile(path, dtype=self.dtype).reshape(-1, length_per_sample)  # load the whole file into memory
+        else:
+            with open(path, 'rb') as fid:
+                nbytes = fid.seek(0, 2)  # seek to EOF to get the file size in bytes
+                flen = nbytes // self.dtype.itemsize
+            self.bin = np.memmap(path, dtype=self.dtype, mode='r', shape=(flen // length_per_sample, length_per_sample))
+    
+    def __len__(self):
+        return self.bin.shape[0]
+    
+    def __getitem__(self, index):
+        return self.process_fn(self.bin[index])
 
 def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):       
 
@@ -96,5 +115,17 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
             return {'text': ret, 
                 'loss_mask':  np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
                 }
+
+    elif dataset_type == 'CompactBinaryDataset':
+        DS_CLASS = BinaryDataset
+        def process_fn(row):
+            text, code = row[:64].astype(np.int64), row[64:].astype(np.int64)  # fixed layout: 64 text tokens + 1024 image tokens
+            text = text[text > -1]  # drop the -1 padding
+            ret = TextCodeTemplate(text, code)
+            ret, attention_mask_sep = pad_to_len(ret)
+            return {'text': ret, 
+                'loss_mask':  np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
+                }
+
     return DS_CLASS(path, process_fn)
 
diff --git a/data_utils/templates.py b/data_utils/templates.py
index bce8b2e..d5d4f99 100755
--- a/data_utils/templates.py
+++ b/data_utils/templates.py
@@ -51,7 +51,16 @@ def concat_codes(*codes):
 
 def TextCodeTemplate(text, code):
     tokenizer = get_tokenizer()
-    text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    if isinstance(text, str):
+        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    else:  # text is already an array of token ids (e.g. from CompactBinaryDataset)
+        text_ids = np.concatenate(
+                (
+                    np.array([tokenizer['[ROI1]']]),
+                    text,
+                ),
+                axis=0
+            )
     code = tokenizer.wrap_code(code)
     return concat_codes(text_ids, code)
 
diff --git a/readme.md b/readme.md
index d874739..2fd61e0 100755
--- a/readme.md
+++ b/readme.md
@@ -50,7 +50,10 @@ wget https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1 -O pretrained/vq
     ```
     tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/
     ```
-2. (Only for training tutorial, skip it for inference.) Download the Alibaba item-title image tokens dataset from our link at [Tianchi]()(*TODO*). Place the lmdb folder under `./data`.
+2. (Only for training tutorial, skip it for inference.) Download the small bird-and-animal example dataset from Tsinghua Cloud:
+    ```
+    wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin
+    ```
 
 ### Run CogView! (Model Inference)
 We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details.
@@ -95,7 +98,7 @@ The output is `{output_path}/scores.txt`, a line of a list of scores, following
 Note: *In the released code, for simplicity, we did not expose the raw API, which supports some advanced generation modes, e.g. text and part of an image.*
 
 ## Training
-Here we use a subset of our dataset from Alibaba item-title for tutorial.
+Here we use the small bird-and-animal example dataset for the tutorial.
 ### Single Node 
 After downloading the dataset, directly run
 ```
diff --git a/scripts/pretrain_single_node.sh b/scripts/pretrain_single_node.sh
index 0973d40..4988cce 100755
--- a/scripts/pretrain_single_node.sh
+++ b/scripts/pretrain_single_node.sh
@@ -17,9 +17,9 @@ HOST_FILE_PATH="hostfile_single"
 
 config_json="$script_dir/ds_config.json"
 gpt_options=" \
-       --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \
+       --experiment-name cogview-bird_animal_tutorial-12-1024-16 \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type TokenizedDataset \
+       --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
        --num-layers 12 \
        --hidden-size 1024 \
@@ -27,7 +27,7 @@ gpt_options=" \
        --save $main_dir/data/checkpoints \
        --train-iters 20000 \
        --resume-dataloader \
-       --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \
+       --train-data ./data/bird_animal.bin \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
diff --git a/scripts/text2image.sh b/scripts/text2image.sh
index 995e355..7e91844 100755
--- a/scripts/text2image.sh
+++ b/scripts/text2image.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # ==== tutorial settings: =====
-# CHECKPOINT_PATH=data/checkpoints/cogview-ali_fashion_tutorial-12-1024-1606-14-06-09
+# CHECKPOINT_PATH=data/checkpoints/cogview-bird_animal_tutorial-12-1024-1608-10-09-38
 # NLAYERS=12
 # NHIDDEN=1024
 # NATT=16
-- 
GitLab
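
For reference, a quick read-side sanity check mirroring what BinaryDataset
exposes to process_fn (a sketch; it assumes the tutorial file downloaded in
the readme step above):

import numpy as np

# map the file read-only and view it as rows of 64 text + 1024 image tokens
bin_data = np.memmap('data/bird_animal.bin', dtype=np.int32, mode='r').reshape(-1, 64 + 1024)
text, code = bin_data[0, :64], bin_data[0, 64:]
print(text[text > -1])  # text token ids with -1 padding removed
print(code.shape)       # (1024,) image code ids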