From 7b90b24c2abafcf3b4195b6d3164be2b21d64fc4 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Tue, 10 Aug 2021 14:16:47 +0000
Subject: [PATCH] add compact bin and bird animal example dataset

---
 arguments.py                    |  4 +++-
 data_utils/configure_data.py    |  4 ++--
 data_utils/datasets.py          | 33 ++++++++++++++++++++++++++++++++-
 data_utils/templates.py         | 11 ++++++++++-
 readme.md                       |  7 +++++--
 scripts/pretrain_single_node.sh |  6 +++---
 scripts/text2image.sh           |  2 +-
 7 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/arguments.py b/arguments.py
index 3407a0f..9825dfb 100755
--- a/arguments.py
+++ b/arguments.py
@@ -260,7 +260,9 @@ def add_data_args(parser):
     group.add_argument('--dataset-type', type=str,
                        default='TokenizedDataset',
                        choices=['TokenizedDataset',
-                                'TextCodeDataset'],
+                                'TextCodeDataset',
+                                'CompactBinaryDataset'
+                                ],
                        help='what type of dataset to use')
     group.add_argument('--max-memory-length', type=int,
                        default=2048,
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index ee4860c..4e4cb0b 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -282,7 +282,7 @@ class RandomMappingDataset(data.Dataset):
         self.wrapped_data = ds
 
     def __len__(self):
-        return len(self.wrapped_data) * 60
+        return len(self.wrapped_data) * 200
 
     def __getitem__(self, index):
         rng = random.Random(index)
@@ -301,7 +301,7 @@ def detect_new_datasets(args):
     found = []
     for _p in os.listdir(args.new_dataset_path):
         p = os.path.join(args.new_dataset_path, _p)
-        if str(p).endswith('lmdb') and not str(os.path.abspath(p)) in current_datasets:
+        if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets:
             found.append(p)
     if len(found) == 0:
         return None
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index 470ac6c..5c1c183 100755
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -59,7 +59,26 @@ class LMDBDataset(Dataset):
             row = pickle.loads(txn.get(key))
         return self.process_fn(row)
 
-    
+
+class BinaryDataset(Dataset):
+    def __init__(self, path, process_fn, length_per_sample=64+1024, dtype='int32', preload=False, **kwargs):
+        assert length_per_sample is not None
+        self.length_per_sample = length_per_sample
+        self.dtype = np.dtype(dtype)
+        self.process_fn = process_fn
+        if preload:
+            self.bin = np.fromfile(path, dtype=self.dtype).reshape(-1, length_per_sample)
+        else:
+            with open(path, 'r') as fid:
+                nbytes = fid.seek(0, 2)
+                flen = fid.tell() // self.dtype.itemsize
+            self.bin = np.memmap(path, dtype=self.dtype, shape=(flen // length_per_sample, length_per_sample))
+
+    def __len__(self):
+        return self.bin.shape[0]
+
+    def __getitem__(self, index):
+        return self.process_fn(self.bin[index])
 
 def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
@@ -96,5 +115,17 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
         return {'text': ret,
             'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
             }
+
+    elif dataset_type == 'CompactBinaryDataset':
+        DS_CLASS = BinaryDataset
+        def process_fn(row):
+            text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024
+            text = text[text>-1]
+            ret = TextCodeTemplate(text, code)
+            ret, attention_mask_sep = pad_to_len(ret)
+            return {'text': ret,
+                'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
+                }
+
     return DS_CLASS(path, process_fn)
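The `process_fn` for `CompactBinaryDataset` above assumes a fixed record layout: 64 int32 text token ids right-padded with -1 (the padding is stripped by `text[text>-1]`), followed by 1024 int32 image codes. Below is a minimal Python sketch of how such a headerless `.bin` file could be packed so that `BinaryDataset` can memory-map it; the helper name `pack_samples` and the output path are illustrative, not part of the repo.

```
import numpy as np

def pack_samples(samples, text_len=64, code_len=1024, out_path='data/example.bin'):
    # samples: iterable of (text_ids, code_ids) pairs of integer token ids
    rows = []
    for text_ids, code_ids in samples:
        assert len(text_ids) <= text_len and len(code_ids) == code_len
        row = np.full(text_len + code_len, -1, dtype=np.int32)  # -1 marks text padding
        row[:len(text_ids)] = text_ids                          # text token ids first
        row[text_len:] = code_ids                               # then the image codes
        rows.append(row)
    np.stack(rows).tofile(out_path)  # flat row-major binary, no header
```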
diff --git a/data_utils/templates.py b/data_utils/templates.py
index bce8b2e..d5d4f99 100755
--- a/data_utils/templates.py
+++ b/data_utils/templates.py
@@ -51,7 +51,16 @@ def concat_codes(*codes):
 
 def TextCodeTemplate(text, code):
     tokenizer = get_tokenizer()
-    text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    if isinstance(text, str):
+        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    else:
+        text_ids = np.concatenate(
+            (
+                np.array([tokenizer['[ROI1]']]),
+                text,
+            ),
+            axis=0
+        )
     code = tokenizer.wrap_code(code)
     return concat_codes(text_ids, code)
 
diff --git a/readme.md b/readme.md
index d874739..2fd61e0 100755
--- a/readme.md
+++ b/readme.md
@@ -50,7 +50,10 @@ wget https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1 -O pretrained/vq
 ```
 tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/
 ```
-2. (Only for training tutorial, skip it for inference.) Download the Alibaba item-title image tokens dataset from our link at [Tianchi]()(*TODO*). Place the lmdb folder under `./data`.
+2. (Only for the training tutorial; skip it for inference.) Download the small "bird-and-animal" example dataset from our link at Tsinghua Cloud.
+```
+wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin
+```
 
 ### Run CogView! (Model Inference)
 We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details.
@@ -95,7 +98,7 @@ The output is `{output_path}/scores.txt`, a line of a list of scores, following
 Note: *In the released codes, for simplicity, we did not expose the raw API , which supports some advanced generation modes, e.g. text and part of image.*
 
 ## Training
-Here we use a subset of our dataset from Alibaba item-title for tutorial.
+Here we use the small bird-and-animal dataset for the tutorial.
 ### Single Node
 After downloading the dataset, directly run
 ```
diff --git a/scripts/pretrain_single_node.sh b/scripts/pretrain_single_node.sh
index 0973d40..4988cce 100755
--- a/scripts/pretrain_single_node.sh
+++ b/scripts/pretrain_single_node.sh
@@ -17,9 +17,9 @@ HOST_FILE_PATH="hostfile_single"
 config_json="$script_dir/ds_config.json"
 
 gpt_options=" \
-       --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \
+       --experiment-name cogview-bird_animal_tutorial-12-1024-16 \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type TokenizedDataset \
+       --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
        --num-layers 12 \
        --hidden-size 1024 \
@@ -27,7 +27,7 @@ gpt_options=" \
        --save $main_dir/data/checkpoints \
        --train-iters 20000 \
        --resume-dataloader \
-       --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \
+       --train-data ./data/bird_animal.bin \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
diff --git a/scripts/text2image.sh b/scripts/text2image.sh
index 995e355..7e91844 100755
--- a/scripts/text2image.sh
+++ b/scripts/text2image.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # ==== tutorial settings: =====
-# CHECKPOINT_PATH=data/checkpoints/cogview-ali_fashion_tutorial-12-1024-1606-14-06-09
+# CHECKPOINT_PATH=data/checkpoints/cogview-bird_animal_tutorial-12-1024-1608-10-09-38
 # NLAYERS=12
 # NHIDDEN=1024
 # NATT=16
-- 
GitLab
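Before launching `scripts/pretrain_single_node.sh`, the downloaded file can be sanity-checked against the 64 + 1024 int32-per-sample layout the patch assumes. This snippet is illustrative and not part of the repo:

```
import os
import numpy as np

path = 'data/bird_animal.bin'
row_bytes = (64 + 1024) * np.dtype('int32').itemsize  # 4352 bytes per sample
nbytes = os.path.getsize(path)
assert nbytes % row_bytes == 0, 'file is not a whole number of samples'
print(nbytes // row_bytes, 'samples of 64 text + 1024 image tokens')
```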