Commit 7b90b24c authored by Ming Ding's avatar Ming Ding

add compact bin and bird animal example dataset

parent 13485ea2
@@ -260,7 +260,9 @@ def add_data_args(parser):
     group.add_argument('--dataset-type', type=str,
                        default='TokenizedDataset',
                        choices=['TokenizedDataset',
-                                'TextCodeDataset'],
+                                'TextCodeDataset',
+                                'CompactBinaryDataset'
+                                ],
                        help='what type of dataset to use')
     group.add_argument('--max-memory-length', type=int, default=2048,
...
...@@ -282,7 +282,7 @@ class RandomMappingDataset(data.Dataset): ...@@ -282,7 +282,7 @@ class RandomMappingDataset(data.Dataset):
self.wrapped_data = ds self.wrapped_data = ds
def __len__(self): def __len__(self):
return len(self.wrapped_data) * 60 return len(self.wrapped_data) * 200
def __getitem__(self, index): def __getitem__(self, index):
rng = random.Random(index) rng = random.Random(index)
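The only change above is the virtual-epoch multiplier (60 to 200). For context, a minimal self-contained sketch of such a random-mapping wrapper is shown below; the `randrange` mapping step is an assumption, since the body of `__getitem__` is collapsed in this hunk.
```
import random

class RandomMappingSketch:
    """Not the repository class: a minimal wrapper exposing a 200x virtual length
    whose samples are drawn deterministically from the wrapped dataset."""
    def __init__(self, ds):
        self.wrapped_data = ds

    def __len__(self):
        return len(self.wrapped_data) * 200

    def __getitem__(self, index):
        rng = random.Random(index)  # same index -> same sample on every pass
        return self.wrapped_data[rng.randrange(len(self.wrapped_data))]  # assumed mapping step

print(len(RandomMappingSketch(list(range(10)))))  # 2000
```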
@@ -301,7 +301,7 @@ def detect_new_datasets(args):
     found = []
     for _p in os.listdir(args.new_dataset_path):
         p = os.path.join(args.new_dataset_path, _p)
-        if str(p).endswith('lmdb') and not str(os.path.abspath(p)) in current_datasets:
+        if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets:
             found.append(p)
     if len(found) == 0:
         return None
...
@@ -59,7 +59,26 @@ class LMDBDataset(Dataset):
             row = pickle.loads(txn.get(key))
             return self.process_fn(row)

+class BinaryDataset(Dataset):
+    def __init__(self, path, process_fn, length_per_sample=64+1024, dtype='int32', preload=False, **kwargs):
+        assert length_per_sample is not None
+        self.length_per_sample = length_per_sample
+        self.dtype = np.dtype(dtype)
+        self.process_fn = process_fn
+        if preload:
+            self.bin = np.fromfile(path, dtype=self.dtype).reshape(-1, length_per_sample)
+        else:
+            with open(path, 'r') as fid:
+                nbytes = fid.seek(0, 2)
+                flen = fid.tell() // self.dtype.itemsize
+            self.bin = np.memmap(path, dtype=self.dtype, shape=(flen // length_per_sample, length_per_sample))
+
+    def __len__(self):
+        return self.bin.shape[0]
+
+    def __getitem__(self, index):
+        return self.process_fn(self.bin[index])
+
 def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
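The new `BinaryDataset` reads a flat file of fixed-length records (by default 64 + 1024 values of `int32` per sample). A hedged sketch of producing a file in that layout is below; the -1 padding of unused text slots is inferred from the `text[text > -1]` filter in the `CompactBinaryDataset` branch further down, and the output path is hypothetical.
```
import numpy as np

TEXT_LEN, CODE_LEN = 64, 1024

def pack_sample(text_ids, code_ids):
    # pad unused text slots with -1; the loader later drops them via text[text > -1]
    text = np.full(TEXT_LEN, -1, dtype=np.int32)
    text[:len(text_ids)] = np.asarray(text_ids, dtype=np.int32)
    code = np.asarray(code_ids, dtype=np.int32)
    assert code.shape == (CODE_LEN,)
    return np.concatenate([text, code])

# toy record with made-up token ids and all-zero image codes
samples = [pack_sample([5, 17, 42], np.zeros(CODE_LEN, dtype=np.int32))]
np.stack(samples).tofile('example.bin')  # hypothetical output path
```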
@@ -96,5 +115,17 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
             return {'text': ret,
                     'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
                     }
+    elif dataset_type == 'CompactBinaryDataset':
+        DS_CLASS = BinaryDataset
+        def process_fn(row):
+            text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024
+            text = text[text>-1]
+            ret = TextCodeTemplate(text, code)
+            ret, attention_mask_sep = pad_to_len(ret)
+            return {'text': ret,
+                    'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
+                    }
     return DS_CLASS(path, process_fn)
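A hedged sketch of what one compact-binary sample contains, memory-mapping the file the same way `BinaryDataset` does and splitting a row the way `process_fn` does; it assumes the example file from the README step has already been downloaded.
```
import numpy as np

length_per_sample = 64 + 1024
bin_data = np.memmap('./data/bird_animal.bin', dtype=np.int32, mode='r').reshape(-1, length_per_sample)

row = bin_data[0]
text, code = row[:64].astype(np.int64), row[64:].astype(np.int64)
text = text[text > -1]  # strip the -1 padding, as process_fn does
print(len(bin_data), 'samples;', len(text), 'text tokens,', len(code), 'image code tokens')
```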
@@ -51,7 +51,16 @@ def concat_codes(*codes):
 def TextCodeTemplate(text, code):
     tokenizer = get_tokenizer()
-    text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    if isinstance(text, str):
+        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    else:
+        text_ids = np.concatenate(
+            (
+                np.array([tokenizer['[ROI1]']]),
+                text,
+            ),
+            axis=0
+        )
     code = tokenizer.wrap_code(code)
     return concat_codes(text_ids, code)
...
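The new branch handles text that arrives as an array of token ids (as produced by `BinaryDataset` rows) rather than a raw string. A small illustration of why the array path is needed, using a made-up placeholder for the `[ROI1]` id:
```
import numpy as np

ROI1_ID = 7  # made-up placeholder; the real id is get_tokenizer()['[ROI1]']
text_ids = np.array([11, 22, 33], dtype=np.int64)  # already-tokenized text from a .bin row

elementwise = [ROI1_ID] + text_ids  # NumPy broadcasts: array([18, 29, 40]) -- not a prepend
prepended = np.concatenate((np.array([ROI1_ID]), text_ids), axis=0)
print(prepended)  # [ 7 11 22 33]
```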
@@ -50,7 +50,10 @@ wget https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1 -O pretrained/vq
 ```
 tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/
 ```
-2. (Only for training tutorial, skip it for inference.) Download the Alibaba item-title image tokens dataset from our link at [Tianchi]()(*TODO*). Place the lmdb folder under `./data`.
+2. (Only for training tutorial, skip it for inference.) Download a small "bird-and-animal" example dataset from our link at Tsinghua Cloud.
+```
+wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin
+```
 ### Run CogView! (Model Inference)
 We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details.
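A hedged sanity check for the downloaded file in the README step above: with the default compact layout, every sample is (64 + 1024) int32 values, so the file size should be a whole number of 1088 * 4 byte records.
```
import os

RECORD_BYTES = (64 + 1024) * 4  # 1088 int32 values per sample
size = os.path.getsize('data/bird_animal.bin')
assert size % RECORD_BYTES == 0, 'unexpected size for the compact binary layout'
print(size // RECORD_BYTES, 'samples')
```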
@@ -95,7 +98,7 @@ The output is `{output_path}/scores.txt`, a line of a list of scores, following
 Note: *In the released codes, for simplicity, we did not expose the raw API , which supports some advanced generation modes, e.g. text and part of image.*
 ## Training
-Here we use a subset of our dataset from Alibaba item-title for tutorial.
+Here we use a subset of our dataset from bird-and-animal for tutorial.
 ### Single Node
 After downloading the dataset, directly run
 ```
...
@@ -17,9 +17,9 @@ HOST_FILE_PATH="hostfile_single"
 config_json="$script_dir/ds_config.json"
 gpt_options=" \
-       --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \
+       --experiment-name cogview-bird_animal_tutorial-12-1024-16 \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type TokenizedDataset \
+       --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
        --num-layers 12 \
        --hidden-size 1024 \
@@ -27,7 +27,7 @@ gpt_options=" \
        --save $main_dir/data/checkpoints \
        --train-iters 20000 \
        --resume-dataloader \
-       --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \
+       --train-data ./data/bird_animal.bin \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
...
 #!/bin/bash
 # ==== tutorial settings: =====
-# CHECKPOINT_PATH=data/checkpoints/cogview-ali_fashion_tutorial-12-1024-1606-14-06-09
+# CHECKPOINT_PATH=data/checkpoints/cogview-bird_animal_tutorial-12-1024-1608-10-09-38
 # NLAYERS=12
 # NHIDDEN=1024
 # NATT=16
...