From 142df57c84a5fd1931a75421eb71fbdd9ffcd318 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Wed, 6 Oct 2021 13:36:00 +0000
Subject: [PATCH] data_utils: drop bundled tokenizers/templates; make dataset creation a pluggable create_dataset_function hook

---
 data_utils/__init__.py          |   6 +-
 data_utils/configure_data.py    |  70 ++++-------
 data_utils/datasets.py          |  79 +-----------
 data_utils/samplers.py          |  18 +--
 data_utils/sp_tokenizer.py      | 150 ----------------------
 data_utils/templates.py         |  83 -------------
 data_utils/unified_tokenizer.py | 212 --------------------------------
 data_utils/vqvae_tokenizer.py   |  86 -------------
 mpu/initialize.py               |   2 +-
 pretrain_gpt2.py                |  27 +++-
 training/deepspeed_training.py  |  17 +--
 11 files changed, 71 insertions(+), 679 deletions(-)
 delete mode 100755 data_utils/sp_tokenizer.py
 delete mode 100755 data_utils/templates.py
 delete mode 100755 data_utils/unified_tokenizer.py
 delete mode 100755 data_utils/vqvae_tokenizer.py
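In outline: the bundled tokenizers and dataset templates leave data_utils, and dataset construction becomes a user-supplied hook (create_dataset_function) that pretrain_gpt2.py passes through training/deepspeed_training.py into data_utils.make_loaders. A minimal, self-contained sketch of that contract follows; ToyDataset and its fake rows are hypothetical placeholders, and only the dict keys ('text', 'loss_mask') and the (path, args) signature come from the diffs below.

    import numpy as np
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        """Hypothetical stand-in for the LMDBDataset/BinaryDataset kept in data_utils/datasets.py."""
        def __init__(self, path, process_fn, length=16):
            # fake fixed-length integer records instead of reading `path`
            self.rows = [np.arange(8, dtype=np.int64) + i for i in range(length)]
            self.process_fn = process_fn

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, idx):
            return self.process_fn(self.rows[idx])

    def create_dataset_function(path, args):
        # the hook receives one path plus args and must return a Dataset whose
        # items look like {'text': int64 array, 'loss_mask': int64 array}
        def process_fn(row):
            return {'text': row, 'loss_mask': (row >= 0).astype(np.int64)}
        return ToyDataset(path, process_fn)

    # downstream (see configure_data.py and deepspeed_training.py below):
    # train, val, test = make_loaders(args, create_dataset_function)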

diff --git a/data_utils/__init__.py b/data_utils/__init__.py
index 509841d..394d3ea 100755
--- a/data_utils/__init__.py
+++ b/data_utils/__init__.py
@@ -8,7 +8,5 @@
 
 # here put the import lib
 
-from .unified_tokenizer import get_tokenizer
-
-from .templates import *
-from .configure_data import make_loaders, detect_new_datasets
\ No newline at end of file
+from .configure_data import make_loaders
+from .datasets import *
\ No newline at end of file
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index 08d6a59..6556673 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -11,16 +11,13 @@ import os
 import sys
 import math
 import random
-from tqdm import tqdm
 import copy
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 from bisect import bisect_right
+from functools import partial
 
-from .unified_tokenizer import get_tokenizer
-from .datasets import get_dataset_by_type
 from torch.utils import data
 from .samplers import DistributedBatchSampler
 
@@ -36,46 +33,37 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     sampler = torch.utils.data.SequentialSampler(dataset)
     drop_last = distributed
     # the GPUs in the same model parallel group receive the same data
-    if distributed:
+    if distributed: # TODO reformat this, but it is not urgent
         gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
         batch_sampler = DistributedBatchSampler(sampler,
-                                                                    batch_size,
-                                                                    drop_last,
-                                                                    rank,
-                                                                    world_size,
-                                                                    gradient_accumulation_steps=gradient_accumulation_steps)
+            batch_size,
+            drop_last,
+            rank,
+            world_size,
+            gradient_accumulation_steps=gradient_accumulation_steps)
     else:
         batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                         batch_size,
                                                         drop_last)
     data_loader = torch.utils.data.DataLoader(dataset,
-                                              batch_sampler=batch_sampler,
-                                              num_workers=args.num_workers,
-                                              pin_memory=True)
+                                            batch_sampler=batch_sampler,
+                                            num_workers=args.num_workers,
+                                            pin_memory=True)
     return data_loader
 
 
-def make_dataset(dataset_type, path, split, args, **kwargs):
+def make_dataset_full(dataset_type, path, split, args, create_dataset_function, **kwargs):
     """function to create datasets+tokenizers for common options"""
     print('make dataset ...', path)
     if split is None:
         split = [1.]
 
     assert isinstance(path, list)
-    # TODO other dsclass, e.g. odps
-    # ds = [get_dataset_by_type(dataset_type, p, args) for p in path]
-    # dataset object can be copied N times
+    
     ds = []
     for p in path:
-        d = get_dataset_by_type(dataset_type, p, args)
-        if p.find('t2i') >= 0:
-            ds.extend([d] * 4)
-            print(f'Enlarge {p} 4 times...')
-        elif p.find('i2t') >= 0:
-            ds.extend([d] * 2)
-            print(f'Enlarge {p} 2 times...')
-        else:
-            ds.append(d)
+        d = create_dataset_function(p, args)
+        ds.append(d)
 
     ds = RandomMappingDataset(ConcatDataset(ds))
 
@@ -84,8 +72,15 @@ def make_dataset(dataset_type, path, split, args, **kwargs):
     # FIXME this will merge valid set and train set.
     return ds
 
-def make_loaders(args):
-    """makes training/val/test"""
+def make_loaders(args, create_dataset_function):
+    """makes training/val/test
+    Args:
+        args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
+        args.split: str. format: "8,1,1". how to split train_data.
+        args.dataset_type: use to create the right datasets. 
+    """
+    make_dataset = partial(make_dataset_full, 
+                        create_dataset_function=create_dataset_function)
 
     world_size = torch.distributed.get_world_size(
         group=mpu.get_data_parallel_group())
@@ -290,22 +285,3 @@ class RandomMappingDataset(data.Dataset):
         rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         index = rng.randint(len(self.wrapped_data))
         return self.wrapped_data[index]
-
-def detect_new_datasets(args):
-    if args.new_dataset_path is None:
-        return None
-    if not os.path.exists(args.new_dataset_path):
-        print('Warning: new_dataset_path not exists... skip detection.')
-        return None
-    current_datasets = [str(os.path.abspath(path)) for path in args.train_data]
-
-    found = []
-    for _p in os.listdir(args.new_dataset_path):
-        p = os.path.join(args.new_dataset_path, _p)
-        if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets:
-            found.append(p)
-    if len(found) == 0:
-        return None
-    else:
-        args.train_data = args.train_data + found
-        return make_loaders(args)    
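Inside make_loaders, functools.partial binds the user hook once so the remaining call sites keep the old make_dataset(...) shape; each path now contributes exactly one dataset (the old t2i x4 / i2t x2 enlargement special cases are gone), and the list is merged via ConcatDataset and reshuffled by RandomMappingDataset. A condensed sketch with simplified signatures; make_loaders_sketch is a hypothetical name and the split/wrapping steps are omitted:

    from functools import partial

    def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
        assert isinstance(path, list)
        ds = [create_dataset_function(p, args) for p in path]
        # the real function wraps the concatenation in RandomMappingDataset
        # and applies `split` before returning
        return ds

    def make_loaders_sketch(args, create_dataset_function):
        make_dataset = partial(make_dataset_full,
                               create_dataset_function=create_dataset_function)
        # older call sites keep their shorter signature:
        return make_dataset(args.train_data, args.split, args)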
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index d12d8f7..11836ad 100755
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -11,26 +11,18 @@ import os
 import sys
 import math
 import random
-from tqdm import tqdm
 import logging
 
 
 import numpy as np
-import torch
-import torch.nn.functional as F
-from torchvision import datasets, transforms
 import pickle
-from collections import namedtuple
 
 from torch.utils.data import Dataset
-import lmdb
 
-from .unified_tokenizer import get_tokenizer
-from .templates import TextCodeTemplate
 
 logger = logging.getLogger(__name__)
 
-
+import lmdb
 class LMDBDataset(Dataset):
     def __init__(self, path, process_fn):
         self.env = lmdb.open(
@@ -55,9 +47,7 @@ class LMDBDataset(Dataset):
 
         with self.env.begin(write=False) as txn:
             key = str(idx).encode('utf-8')
-
             row = pickle.loads(txn.get(key))
-
             return self.process_fn(row)
 
 class BinaryDataset(Dataset):
@@ -80,70 +70,3 @@ class BinaryDataset(Dataset):
     def __getitem__(self, index):
         return self.process_fn(self.bin[index])
 
-def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): 
-    kwargs_to_dataset = {}      
-
-    tokenizer = get_tokenizer()
-    if args.layout[-1] > args.max_position_embeddings:
-        ml = args.layout[-1]
-    else:
-        ml = args.max_position_embeddings
-
-    def pad_to_len(ret):
-        if len(ret) < ml: # pad
-            return np.concatenate((ret, 
-                np.array([tokenizer['[PAD]']] * (ml - len(ret)))),
-                axis=0), len(ret)
-        else:
-            if len(ret) > ml:
-                logger.warning('Out of max len, truncated.')
-            return ret[:ml], ml
-
-    if dataset_type == 'TokenizedDataset':
-        # already tokenized when saved
-        def process_fn(row):
-            ret, attention_mask_sep = pad_to_len(row.flatten())
-            return {'text': ret, 
-                'loss_mask':  np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
-                }
-
-    elif dataset_type == 'TextCodeDataset':
-        def process_fn(row):
-            text, code = row[0], row[1].flatten()
-            ret = TextCodeTemplate(text, code)
-            ret, attention_mask_sep = pad_to_len(ret)
-            return {'text': ret, 
-                'loss_mask':  np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
-                }
-
-    elif dataset_type == 'CompactBinaryDataset':
-        layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME
-        DS_CLASS = BinaryDataset
-        kwargs_to_dataset['length_per_sample'] = layout[-1]
-        def process_fn(row):
-            row = row.astype(np.int64)
-        
-            codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
-            
-            text = row[:layout[0]]
-            text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
-            n_pad = layout[0]-3-len(text)
-            parts = [
-                np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
-                TextCodeTemplate(text, codes[1]), # FIXME
-                *codes[2:] # FIXME
-            ]
-            ret = np.concatenate(parts, axis=0)
-            return {'text': ret, 
-                'loss_mask':  np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS]
-                }
-    elif dataset_type == 'BinaryDataset':
-        DS_CLASS = BinaryDataset
-        def process_fn(row):
-            loss_mask = (row >= 0).astype(np.int64)
-            return {'text': row.astype(np.int64), 
-                'loss_mask':  loss_mask
-                }
-
-    return DS_CLASS(path, process_fn, **kwargs_to_dataset)
-
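What survives in datasets.py are two thin containers that defer all decoding to the caller's process_fn. The storage details of BinaryDataset are outside this hunk; the following is a plausible, simplified stand-in (memory-mapped, fixed-length integer records), where TinyBinaryDataset and its dtype are assumptions:

    import numpy as np
    from torch.utils.data import Dataset

    class TinyBinaryDataset(Dataset):
        """Hypothetical simplification: a flat binary file of fixed-length int records."""
        def __init__(self, path, process_fn, length_per_sample, dtype=np.int32):
            self.bin = np.memmap(path, dtype=dtype, mode='r').reshape(-1, length_per_sample)
            self.process_fn = process_fn

        def __len__(self):
            return self.bin.shape[0]

        def __getitem__(self, index):
            return self.process_fn(self.bin[index])

    # usage mirrors the real classes:
    # ds = BinaryDataset(path, process_fn, length_per_sample=5440)
    # ds = LMDBDataset(path, process_fn)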
diff --git a/data_utils/samplers.py b/data_utils/samplers.py
index fa8474e..90aac94 100755
--- a/data_utils/samplers.py
+++ b/data_utils/samplers.py
@@ -22,7 +22,7 @@ from torch.utils import data
 import numpy as np
 
 class RandomSampler(data.sampler.Sampler):
-    r"""
+    r""" 
     Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
     but this class lets the user set an epoch like DistributedSampler
     Samples elements randomly. If without replacement, then sample from a shuffled dataset.
@@ -139,14 +139,14 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
                 self.sampler.wrap_around -= (self.batch_size)
                 self.wrap_around += (len(batch))
                 self.wrap_around %= self.batch_size
-                if isinstance(self.sampler, TransposedSampler):
-                    for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
-                        if i == 0:
-                            continue
-                        batch.append(idx)
-                        new_batch_len = len(batch)
-                        if len(batch) == self.batch_size:
-                            break
+                # if isinstance(self.sampler, TransposedSampler):
+                #     for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
+                #         if i == 0:
+                #             continue
+                #         batch.append(idx)
+                #         new_batch_len = len(batch)
+                #         if len(batch) == self.batch_size:
+                #             break
             yield self._batch(batch)
         if self.wrap_last:
             self.sampler.wrap_around += self.batch_size
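For intuition only: DistributedBatchSampler accumulates one global batch from the wrapped sampler and yields each data-parallel rank its own slice, so all ranks step through the data in lockstep; the block commented out above only applied to the removed TransposedSampler. A rough conceptual illustration (local_slice is not part of the class, and the real implementation also handles wrap-around and gradient accumulation):

    def local_slice(global_batch, rank, world_size):
        per_rank = len(global_batch) // world_size
        return global_batch[rank * per_rank:(rank + 1) * per_rank]

    print(local_slice(list(range(8)), rank=1, world_size=4))  # [2, 3]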
diff --git a/data_utils/sp_tokenizer.py b/data_utils/sp_tokenizer.py
deleted file mode 100755
index 01ac73b..0000000
--- a/data_utils/sp_tokenizer.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""
-from https://github.com/openai/gpt-2/, changed for chinese
-"""
-import json
-import os
-import sentencepiece as spm
-
-"""
-SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 
-systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 
-subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the 
-extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 
-system that does not depend on language-specific pre/postprocessing.
-https://github.com/google/sentencepiece
-
-pip install sentencepiece
-
-or  git clone https://github.com/google/sentencepiece.git
-python setup.py install
-
-"""
-PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model"
-
-
-def get_pairs(word):
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
-
-class Encoder:
-    def __init__(self, encoder, bpe_merges):
-        self.encoder = encoder
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-        self.max_len = 0
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token)
-        pairs = get_pairs(word)
-        if not pairs:
-            return token
-
-        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
-                    new_word.append(first + second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = ' '.join(word)
-        self.cache[token] = word
-        return word
-
-    def encode(self, text):
-        return [self.encoder.get(token, 1) for token in self.tokenize(text)]
-
-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
-        return text
-
-    def tokenize(self, text):
-        bpe_tokens = []
-        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
-        return bpe_tokens
-
-    def convert_tokens_to_ids(self, tokens):
-        return [self.encoder.get(token, 1) for token in tokens]
-
-
-class Encoder_SP:
-    def __init__(self, model_path):
-        self.sp = spm.SentencePieceProcessor()
-        self.sp.Load(model_path)
-        self.num_tokens = self.sp.vocab_size()
-
-    def encode(self, text):
-        """
-        text="...."
-        """
-        return self.sp.EncodeAsIds(text)
-
-    def decode(self, tokens):
-        """
-        tokens=[x1,x2,...]
-        """
-        text = [int(token) for token in tokens]
-        return self.sp.DecodeIds(text)
-
-    def tokenize(self, text):
-        return self.sp.EncodeAsPieces(text)
-
-    def convert_tokens_to_ids(self, tokens):
-        return [self.sp.PieceToId(token) for token in tokens]
-
-    def convert_token_to_id(self, token):
-        return self.sp.PieceToId(token)
-
-    def convert_id_to_token(self, idx):
-        return self.sp.IdToPiece(idx)
-
-
-def get_encoder(encoder_file, bpe_file):
-    # the branching below lets the same entry function also handle a sentencepiece model
-    filepath, filename = os.path.split(encoder_file)
-    shotname, extension = os.path.splitext(filename)
-
-    if (".model" == extension) and (bpe_file == ""):
-        return Encoder_SP(encoder_file)
-    else:
-        with open(encoder_file, 'r', encoding="utf-8") as f:
-            encoder = json.load(f)
-        with open(bpe_file, 'r', encoding="utf-8") as f:
-            bpe_data = f.read()
-        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
-        return Encoder(
-            encoder=encoder,
-            bpe_merges=bpe_merges,
-        )
-
-
-def from_pretrained():
-    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
diff --git a/data_utils/templates.py b/data_utils/templates.py
deleted file mode 100755
index d5d4f99..0000000
--- a/data_utils/templates.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   templates.py
-@Time    :   2021/01/11 22:28:57
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-from tqdm import tqdm
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .unified_tokenizer import get_tokenizer
-from .vqvae_tokenizer import sqrt_int
-
-def concat_codes(*codes):
-    is_numpy = is_tensor = False
-    for code in codes:
-        if isinstance(code, np.ndarray):
-            is_numpy = True
-        if isinstance(code, torch.Tensor):
-            is_tensor = True
-            device = code.device
-    if is_tensor:
-        return torch.cat(
-            [
-                torch.tensor(code, device=device)
-                for code in codes
-            ]
-        )
-    elif is_numpy:
-        return np.concatenate(
-            [
-                np.array(code)
-                for code in codes
-            ],
-            axis=0
-        )
-    else:
-        ret = []
-        for code in codes:
-            ret = ret + code
-        return ret
-
-def TextCodeTemplate(text, code):
-    tokenizer = get_tokenizer()
-    if isinstance(text, str):
-        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
-    else:
-        text_ids = np.concatenate(
-                (
-                    np.array([tokenizer['[ROI1]']]),
-                    text,
-                ),
-                axis=0
-            )
-    code = tokenizer.wrap_code(code)
-    return concat_codes(text_ids, code)
-
-def Code2CodeTemplate(text, code0, code1):
-    tokenizer = get_tokenizer()
-    text_ids = tokenizer.parse_query(text) if isinstance(text, str) else text
-    code0 = tokenizer.wrap_code(code0)
-    code1 = tokenizer.wrap_code(code1, idx=2)
-    return concat_codes(text_ids, code0, code1)
-
-def PureTextTemplate(text):
-    tokenizer = get_tokenizer()
-    return tokenizer(text) + [tokenizer['[SEP]']]
-
-
-
-
-
-
-
diff --git a/data_utils/unified_tokenizer.py b/data_utils/unified_tokenizer.py
deleted file mode 100755
index 1ce520b..0000000
--- a/data_utils/unified_tokenizer.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   unified_tokenizer.py
-@Time    :   2021/01/11 16:36:33
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-from tqdm import tqdm
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .sp_tokenizer import from_pretrained
-from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
-
-class UnifiedTokenizer(object):
-    def __init__(self, img_tokenizer_path, device, img_tokenizer_num_tokens=None):
-        self.device = device
-        if img_tokenizer_path is None and img_tokenizer_num_tokens is not None:
-            # pretraining but only know the vocab size of VQVAE, which is developing fast
-            self.img_tokenizer = FakeTokenizer(img_tokenizer_num_tokens)
-        else:
-            self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
-        self.txt_tokenizer = from_pretrained()
-        self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
-        self.raw_command_tokens = [
-            ('[PAD]', 0),
-            ('[BOI1]', 1), # Begin
-            ('[BOI2]', 2),
-            ('[BOI3]', 3),
-            ('[EOI1]', 4), # End
-            ('[EOI2]', 5),
-            ('[EOI3]', 6),
-            ('[ROI1]', 7), # Reference
-            ('[ROI2]', 8),
-            ('[ROI3]', 9),
-            ('[SEP]', 10),
-            ('[MASK]', 11),
-            ('[CLS]', 12),
-            ('[ENC]', 13),
-            ('[TINY]', 14), # 8 * 8
-            ('[SMALL]', 15), # 16 * 16
-            ('[BASE]', 16), # 32 * 32
-            ('[BIG]', 17), # 64 * 64
-            ('[POS0]', 18), # 58210
-            ('[POS1]', 19),
-            ('[POS2]', 20),
-            ('[POS3]', 21),
-            ('[POS4]', 22),
-            ('[POS5]', 23),
-            ('[POS6]', 24),
-            ('[POS7]', 25),
-            ('[POS8]', 26)
-            # Please leave the ``size tokens'' at the back of command tokens
-        ]
-        self.command_tokens = {
-            k: v + self.num_tokens
-            for k, v in self.raw_command_tokens
-        }
-        self.num_tokens += len(self.raw_command_tokens)
-    
-    def __getitem__(self, command_token):
-        return self.command_tokens[command_token]
-
-    def __len__(self):
-        """total number of tokens"""
-        return self.num_tokens
-
-    def __call__(self, inputs, process_fn=None):
-        """run preprocessing and encode inputs as Ids
-            CANNOT contain command tokens"""
-        if isinstance(inputs, torch.Tensor): # image
-            if len(inputs.shape) == 3:
-                inputs = inputs.unsqueeze(0)
-            return self.img_tokenizer.EncodeAsIds(inputs)
-        return self.EncodeAsIds(inputs, process_fn=process_fn)
-    
-    def EncodeAsIds(self, text, process_fn=None):
-        processed_text = text
-        if process_fn is not None:
-            processed_text = process_fn(processed_text)
-        ids = self.txt_tokenizer.encode(processed_text)
-        return [x + self.img_tokenizer.num_tokens for x in ids]
-    
-    def DecodeIds(self, ids):
-        ret, img_buffer, txt_buffer, ret_imgs = [], [], [], []
-        try:
-            for x in ids:
-                if self.num_tokens - len(self.raw_command_tokens) <= x:
-                    # command tokens
-                    token = self.raw_command_tokens[x - (self.num_tokens - len(self.raw_command_tokens))][0]
-                    if token.startswith('[EOI') and len(img_buffer) > 0:
-                        # dump image
-                        ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
-                        img_buffer = []
-                    if len(txt_buffer) > 0:
-                        # dump text
-                        ret.append(self.txt_tokenizer.decode(txt_buffer))
-                        txt_buffer = []
-                    ret.append(token)
-                elif x < self.img_tokenizer.num_tokens:
-                    img_buffer.append(x)
-                else:
-                    txt_buffer.append(x - self.img_tokenizer.num_tokens)
-            
-            if len(img_buffer) > 0:
-                # dump image
-                ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
-                img_buffer = []
-            if len(txt_buffer) > 0:
-                # dump text
-                ret.append(self.txt_tokenizer.decode(txt_buffer))
-                txt_buffer = []
-        except ValueError:
-            print('Value error in tokenization, skipping...')
-        return ret, ret_imgs
-
-    def wrap_code(self, code, idx=1):
-        s = sqrt_int(len(code))
-        prefix = {8:'[TINY]', 16:'[SMALL]', 32:'[BASE]', 64:'[BIG]'}[s]
-        boi = {1:'[BOI1]', 2: '[BOI2]', 3:'[BOI3]'}[idx]
-        eoi = {1:'[EOI1]', 2: '[EOI2]', 3:'[EOI3]'}[idx]
-    
-        if isinstance(code, list):
-            return [self.command_tokens[prefix], self.command_tokens[boi]] + \
-                code + [self.command_tokens[eoi]]
-        elif isinstance(code, np.ndarray):
-            return np.concatenate(
-                (
-                    np.array([self.command_tokens[prefix], self.command_tokens[boi]]),
-                    code,
-                    np.array([self.command_tokens[eoi]])
-                ),
-                axis=0
-            )
-        elif isinstance(code, torch.Tensor):
-            return torch.cat(
-                (
-                    torch.tensor([self.command_tokens[prefix], self.command_tokens[boi]]),
-                    code, 
-                    np.array([self.command_tokens[eoi]])
-                )
-            )
-        else:
-            raise ValueError('')
-
-    def parse_query(self, query, img_size=256):
-        text_buffer = []
-        ret = []
-        for part in query.split(' '):
-            if part in self.command_tokens:
-                if len(text_buffer) > 0:
-                    # dump text ids
-                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
-                    text_buffer = []
-                if part == '[MASK]':
-                    ret.append(-1)
-                else:
-                    ret.append(self.command_tokens[part])
-            elif part.startswith('[MASK]*'): # special lang *N
-                c = int(part[7:])
-                assert c > 0
-                if len(text_buffer) > 0:
-                    # dump text ids
-                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
-                    text_buffer = []
-                ret.extend([-1] * c)
-            elif part.startswith('[Image'): # [Image*N]path
-                c = part[6:]
-                assert len(c) > 0
-                num_codes, img_path = c.split(']')
-                if num_codes == '':
-                    num_codes = 1024
-                else:
-                    num_codes = int(num_codes)
-                
-                raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size)
-                img_codes = self.img_tokenizer.EncodeAsIds(raw_img) # [1, 32*32]
-                img_codes[0, num_codes:] = -1
-                img_codes = img_codes[0].tolist()
-                ret.extend(img_codes)
-            else:
-                text_buffer.append(part)
-
-        if len(text_buffer) > 0:
-            # dump text ids
-            ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
-            text_buffer = []
-        return ret
-
-def get_tokenizer(args=None):
-    if not hasattr(get_tokenizer, 'tokenizer'):
-        # the first time to load the tokenizer, specify img_tokenizer_path
-        get_tokenizer.tokenizer = UnifiedTokenizer(
-            args.img_tokenizer_path, 
-            device=torch.cuda.current_device(),
-            img_tokenizer_num_tokens=args.img_tokenizer_num_tokens
-            )
-    return get_tokenizer.tokenizer
-
-class FakeTokenizer(object):
-    def __init__(self, num_tokens):
-        self.num_tokens = num_tokens
-    def __len__(self):
-        return self.num_tokens
\ No newline at end of file
diff --git a/data_utils/vqvae_tokenizer.py b/data_utils/vqvae_tokenizer.py
deleted file mode 100755
index 56ee251..0000000
--- a/data_utils/vqvae_tokenizer.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   vqvae_tokenizer.py
-@Time    :   2021/01/11 17:57:43
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-from tqdm import tqdm
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-
-from vqvae import new_model, img2code, code2img
-from torchvision import transforms
-from PIL import Image
-
-def is_exp2(x):
-    t = math.log2(x)
-    return abs(t - int(t)) < 1e-4
-def sqrt_int(x):
-    r = int(math.sqrt(x) + 1e-4)
-    assert r * r == x
-    return r
-
-class VQVAETokenizer(object):
-    def __init__(self, 
-            model_path, 
-            device='cuda'
-        ):
-        ckpt = torch.load(model_path, map_location=torch.device(device))
-
-        model = new_model()
-
-        if list(ckpt.keys())[0].startswith('module.'):
-            ckpt = {k[7:]: v for k, v in ckpt.items()}
-
-        model.load_state_dict(ckpt)
-        model = model.to(device)
-        model.eval()
-
-        self.model = model
-        self.device = device
-        self.image_tokens = model.quantize_t.n_embed
-        self.num_tokens = model.quantize_t.n_embed
-        self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
-
-    def __len__(self):
-        return self.num_tokens
-
-    def EncodeAsIds(self, img, add_normalization=False):
-        assert len(img.shape) == 4 # [b, c, h, w]
-        if add_normalization:
-            img = self.tr_normalize(img)
-        return img2code(self.model, img)
-
-    def DecodeIds(self, code, shape=None):
-        if shape is None:
-            if isinstance(code, list):
-                code = torch.tensor(code, device=self.device)
-            s = sqrt_int(len(code.view(-1)))
-            assert s * s == len(code.view(-1))
-            shape = (1, s, s)
-        code = code.view(shape)
-        out = code2img(self.model, code)
-        return out
-
-    def read_img(self, path, img_size=256):
-        tr = transforms.Compose([
-            transforms.Resize(img_size), 
-            transforms.CenterCrop(img_size),
-            transforms.ToTensor(),
-        ])
-        img = tr(Image.open(path))
-        if img.shape[0] == 4:
-            img = img[:-1]
-        img = self.tr_normalize(img)
-        img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
-        return img  
\ No newline at end of file
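The four deleted modules (sp_tokenizer, templates, unified_tokenizer, vqvae_tokenizer) are superseded by the standalone tokenization package that pretrain_gpt2.py imports below; note that TextCodeTemplate there now takes the tokenizer explicitly instead of fetching a global one. The replacement imports as they appear in that diff (availability of the package is assumed):

    from tokenization import get_tokenizer              # was data_utils.unified_tokenizer.get_tokenizer
    from tokenization.cogview import TextCodeTemplate   # was data_utils.templates.TextCodeTemplate

    # tokenizer = get_tokenizer(args)                   # initialized once with args (deepspeed_training.py)
    # merged = TextCodeTemplate(text, code, tokenizer)  # tokenizer passed explicitly (pretrain_gpt2.py)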
diff --git a/mpu/initialize.py b/mpu/initialize.py
index 0a3e15a..3597c2c 100755
--- a/mpu/initialize.py
+++ b/mpu/initialize.py
@@ -36,7 +36,7 @@ def initialize_model_parallel(model_parallel_size_):
 
     Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
     use 2 GPUs to parallelize the model. The present function will
-    create 4 model parallel groups and 2 data parallel grous as:
+    create 4 model parallel groups and 2 data parallel groups as:
         4 model parallel groups:
             [g0, g1], [g2, g3], [g4, g5], [g6, g7]
         2 data parallel groups:
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 3b8ccc4..a8a0ea2 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -12,11 +12,15 @@ import sys
 import math
 import random
 import torch
+import numpy as np
 
 import mpu
 from arguments import get_args
 from model.base_model import BaseModel
 from training.deepspeed_training import main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
 
 def get_masks_and_position_ids(data,
                             loss_mask=None,
@@ -99,6 +103,27 @@ def forward_step(data_iterator, model, args, timers):
     
     return loss, {}
 
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+        
+        text = row[:layout[0]]
+        text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
+        merged = TextCodeTemplate(text, codes[1], tokenizer)
+        n_pad = args.max_sequence_length - len(merged)
+        parts = [
+            merged, 
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64)
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret, 
+            'loss_mask': np.array([1]*len(merged) + [0]*n_pad)
+            }
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
 if __name__ == '__main__':
     args = get_args()
-    main(args, model_cls=BaseModel, forward_step=forward_step)
+    main(args, model_cls=BaseModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
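The layout list in create_dataset_function holds cumulative offsets into each fixed-length binary record: 64 text tokens followed by 16x16, 32x32 and 64x64 image-code blocks; the new process_fn keeps only the 32x32 block and pads at the end, whereas the deleted CompactBinaryDataset branch padded at the front. The arithmetic spelled out (plain Python, no assumptions beyond the numbers above):

    layout = [64, 64 + 16**2, 64 + 16**2 + 32**2, 64 + 64**2 + 16**2 + 32**2]
    print(layout)   # [64, 320, 1344, 5440]
    sizes = [layout[0]] + [b - a for a, b in zip(layout, layout[1:])]
    print(sizes)    # [64, 256, 1024, 4096] -> text, 16*16, 32*32, 64*64 code blocks
    # codes[1] in process_fn is row[320:1344], i.e. the 32*32 = 1024-token block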
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 95d8ac4..5519003 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -40,11 +40,12 @@ from data_utils import make_loaders, get_tokenizer
 
 
 
-def main(args, model_cls, forward_step, init_step=None):
+def main(args, model_cls, forward_step_function, create_dataset_function, init_function=None):
     """Main training program."""
     hooks = {
-        'forward_step': forward_step,
-        'init_step': init_step
+        'forward_step': forward_step_function,
+        'init_function': init_function,
+        'create_dataset_function': create_dataset_function
         }
     
     torch.backends.cuda.matmul.allow_tf32 = False
@@ -63,7 +64,7 @@ def main(args, model_cls, forward_step, init_step=None):
     # init tokenizer
     tokenizer = get_tokenizer(args)
     # Data stuff.
-    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args)
+    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args, hooks['create_dataset_function'])
 
     # Model, optimizer, and learning rate.
     model, optimizer = setup_model_and_optimizer(args, model_cls)
@@ -109,8 +110,8 @@ def main(args, model_cls, forward_step, init_step=None):
         val_data_iterator = None
         
     # init hook before training
-    if hooks['init_func'] is not None:
-        hooks['init_func'](args, model, optimizer)
+    if hooks['init_function'] is not None:
+        hooks['init_function'](args, model, optimizer)
 
     # training 
     iteration = 0
@@ -525,14 +526,14 @@ def set_random_seed(seed):
         mpu.model_parallel_cuda_manual_seed(seed)
 
 
-def get_train_val_test_data(args):
+def get_train_val_test_data(args, create_dataset_function):
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
 
     (train_data, val_data, test_data) = (None, None, None)
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        train_data, val_data, test_data = make_loaders(args)
+        train_data, val_data, test_data = make_loaders(args, create_dataset_function)
         num_tokens = get_tokenizer().num_tokens
 
         before = num_tokens
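The renamed hooks also resolve a latent mismatch: the old code stored the callback under 'init_step' but looked it up as hooks['init_func'], a lookup that would fail as written; both sides now agree on 'init_function'. A minimal sketch of the unified plumbing (names follow the hunks above; main_sketch, the lambdas and the (args, model, optimizer) placeholders are illustrative only):

    def main_sketch(args, model_cls, forward_step_function, create_dataset_function, init_function=None):
        hooks = {
            'forward_step': forward_step_function,
            'init_function': init_function,
            'create_dataset_function': create_dataset_function,
        }
        if hooks['init_function'] is not None:
            hooks['init_function'](args, None, None)  # (args, model, optimizer) in the real code
        return hooks

    hooks = main_sketch(None, None, lambda *a: None, lambda p, a: None)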
-- 
GitLab