diff --git a/tokenization/__init__.py b/tokenization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b481dc7ec60f9aed1acae5963726187cce81bb
--- /dev/null
+++ b/tokenization/__init__.py
@@ -0,0 +1,36 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   __init__.py
+@Time    :   2021/10/06 17:58:04
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+def get_tokenizer(args=None):
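+    """Return the process-wide singleton tokenizer, building it on the first call.
+
+    The first call must pass `args`, e.g. `get_tokenizer(args)`; later calls
+    can use `get_tokenizer()` to fetch the cached instance.
+    """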
+    if not hasattr(get_tokenizer, 'tokenizer'):
+        # the first call must provide args to build the tokenizer
+        assert args is not None, 'get_tokenizer requires args on its first call'
+        if args.tokenizer_type == 'cogview':
+            from .cogview import UnifiedTokenizer
+            get_tokenizer.tokenizer = UnifiedTokenizer(
+                args.img_tokenizer_path, 
+                device=torch.cuda.current_device()
+                )
+        elif args.tokenizer_type == 'glm':
+            raise NotImplementedError('glm tokenizer is not supported yet')
+        else:
+            assert args.vocab_size > 0
+            get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
+    return get_tokenizer.tokenizer
+
+class FakeTokenizer(object):
+    def __init__(self, num_tokens):
+        self.num_tokens = num_tokens
+    def __len__(self):
+        return self.num_tokens
\ No newline at end of file
diff --git a/tokenization/cogview/__init__.py b/tokenization/cogview/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df817a4fc7d2d474303795a668bf6f7e9c67699
--- /dev/null
+++ b/tokenization/cogview/__init__.py
@@ -0,0 +1,16 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   __init__.py
+@Time    :   2021/10/06 18:21:15
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+from .unified_tokenizer import UnifiedTokenizer
+from .templates import *
\ No newline at end of file
diff --git a/tokenization/cogview/sp_tokenizer.py b/tokenization/cogview/sp_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..01ac73b248d34642c043048476b597a6b26a6a6d
--- /dev/null
+++ b/tokenization/cogview/sp_tokenizer.py
@@ -0,0 +1,150 @@
+"""
+from https://github.com/openai/gpt-2/, changed for chinese
+"""
+import json
+import os
+import sentencepiece as spm
+
+"""
+SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 
+systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 
+subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.] and unigram language model [Kudo]) with the
+extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 
+system that does not depend on language-specific pre/postprocessing.
+https://github.com/google/sentencepiece
+
+pip install sentencepiece
+
+or  git clone https://github.com/google/sentencepiece.git
+python setup.py install
+
+"""
+PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model"
+
+
+def get_pairs(word):
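+    """Return the set of adjacent symbol pairs in `word` (a tuple of symbols)."""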
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.max_len = 0
+
+    def bpe(self, token):
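+        """Greedily apply the learned BPE merges to `token`, returning the
+        merged subwords joined by spaces."""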
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
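+        # unknown tokens fall back to id 1 (assumed to be the <unk> id of this vocab)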
+        return [self.encoder.get(token, 1) for token in self.tokenize(text)]
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        return text
+
+    def tokenize(self, text):
+        return self.bpe(text).split(' ')
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.encoder.get(token, 1) for token in tokens]
+
+
+class Encoder_SP:
+    def __init__(self, model_path):
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(model_path)
+        self.num_tokens = self.sp.vocab_size()
+
+    def encode(self, text):
+        """Encode a raw string into a list of SentencePiece ids."""
+        return self.sp.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        """Decode a sequence of ids `[x1, x2, ...]` back into a string."""
+        text = [int(token) for token in tokens]
+        return self.sp.DecodeIds(text)
+
+    def tokenize(self, text):
+        return self.sp.EncodeAsPieces(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.sp.PieceToId(token) for token in tokens]
+
+    def convert_token_to_id(self, token):
+        return self.sp.PieceToId(token)
+
+    def convert_id_to_token(self, idx):
+        return self.sp.IdToPiece(idx)
+
+
+def get_encoder(encoder_file, bpe_file):
+    # The following lets the same entry point also handle SentencePiece models
+    filepath, filename = os.path.split(encoder_file)
+    shortname, extension = os.path.splitext(filename)
+
+    if (".model" == extension) and (bpe_file == ""):
+        return Encoder_SP(encoder_file)
+    else:
+        with open(encoder_file, 'r', encoding="utf-8") as f:
+            encoder = json.load(f)
+        with open(bpe_file, 'r', encoding="utf-8") as f:
+            bpe_data = f.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+        return Encoder(
+            encoder=encoder,
+            bpe_merges=bpe_merges,
+        )
+
+
+def from_pretrained():
+    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
diff --git a/tokenization/cogview/templates.py b/tokenization/cogview/templates.py
new file mode 100755
index 0000000000000000000000000000000000000000..4d665785dcdb0d011abf490fbb8725f76f6467ac
--- /dev/null
+++ b/tokenization/cogview/templates.py
@@ -0,0 +1,77 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   templates.py
+@Time    :   2021/01/11 22:28:57
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def concat_codes(*codes):
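+    """Concatenate a mix of lists, numpy arrays and torch tensors,
+    promoting the result to the richest type seen:
+    torch.Tensor > np.ndarray > list."""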
+    is_numpy = is_tensor = False
+    for code in codes:
+        if isinstance(code, np.ndarray):
+            is_numpy = True
+        if isinstance(code, torch.Tensor):
+            is_tensor = True
+            device = code.device
+    if is_tensor:
+        return torch.cat(
+            [
+                # as_tensor leaves tensors as-is (avoids the copy-construct warning)
+                torch.as_tensor(code, device=device)
+                for code in codes
+            ]
+        )
+    elif is_numpy:
+        return np.concatenate(
+            [
+                np.array(code)
+                for code in codes
+            ],
+            axis=0
+        )
+    else:
+        ret = []
+        for code in codes:
+            ret = ret + code
+        return ret
+
+def TextCodeTemplate(text, code, tokenizer):
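+    """Prepend `[ROI1]` to the text ids and append the wrapped image code."""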
+    if isinstance(text, str):
+        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
+    else:
+        text_ids = np.concatenate(
+                (
+                    np.array([tokenizer['[ROI1]']]),
+                    text,
+                ),
+                axis=0
+            )
+    code = tokenizer.wrap_code(code)
+    return concat_codes(text_ids, code)
+
+def Code2CodeTemplate(text, code0, code1, tokenizer):
+    text_ids = tokenizer.parse_query(text) if isinstance(text, str) else text
+    code0 = tokenizer.wrap_code(code0)
+    code1 = tokenizer.wrap_code(code1, idx=2)
+    return concat_codes(text_ids, code0, code1)
+
+def PureTextTemplate(text, tokenizer):
+    return tokenizer(text) + [tokenizer['[SEP]']]
+
diff --git a/tokenization/cogview/unified_tokenizer.py b/tokenization/cogview/unified_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..bf66d9a7cb4dbb0e43582ea81b341a001d7651f2
--- /dev/null
+++ b/tokenization/cogview/unified_tokenizer.py
@@ -0,0 +1,193 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   unified_tokenizer.py
+@Time    :   2021/01/11 16:36:33
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .sp_tokenizer import from_pretrained
+from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
+
+class UnifiedTokenizer(object):
+    def __init__(self, img_tokenizer_path, device):
+        self.device = device
+        self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
+        self.txt_tokenizer = from_pretrained()
+        self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
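+        # unified id space: [0, img) = image codes, [img, img+txt) = shifted
+        # text ids, followed by the command tokens registered below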
+        self.raw_command_tokens = [
+            ('[PAD]', 0),
+            ('[BOI1]', 1), # Begin
+            ('[BOI2]', 2),
+            ('[BOI3]', 3),
+            ('[EOI1]', 4), # End
+            ('[EOI2]', 5),
+            ('[EOI3]', 6),
+            ('[ROI1]', 7), # Reference
+            ('[ROI2]', 8),
+            ('[ROI3]', 9),
+            ('[SEP]', 10),
+            ('[MASK]', 11),
+            ('[CLS]', 12),
+            ('[ENC]', 13),
+            ('[TINY]', 14), # 8 * 8
+            ('[SMALL]', 15), # 16 * 16
+            ('[BASE]', 16), # 32 * 32
+            ('[BIG]', 17), # 64 * 64
+            ('[POS0]', 18), # 58210
+            ('[POS1]', 19),
+            ('[POS2]', 20),
+            ('[POS3]', 21),
+            ('[POS4]', 22),
+            ('[POS5]', 23),
+            ('[POS6]', 24),
+            ('[POS7]', 25),
+            ('[POS8]', 26)
+            # Please leave the ``size tokens'' at the back of command tokens
+        ]
+        self.command_tokens = {
+            k: v + self.num_tokens
+            for k, v in self.raw_command_tokens
+        }
+        self.num_tokens += len(self.raw_command_tokens)
+    
+    def __getitem__(self, command_token):
+        return self.command_tokens[command_token]
+
+    def __len__(self):
+        """total number of tokens"""
+        return self.num_tokens
+
+    def __call__(self, inputs, process_fn=None):
+        """run preprocessing and encode inputs as Ids
+            CANNOT contain command tokens"""
+        if isinstance(inputs, torch.Tensor): # image
+            if len(inputs.shape) == 3:
+                inputs = inputs.unsqueeze(0)
+            return self.img_tokenizer.EncodeAsIds(inputs)
+        return self.EncodeAsIds(inputs, process_fn=process_fn)
+    
+    def EncodeAsIds(self, text, process_fn=None):
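+        """Encode text and shift the ids past the image-code range."""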
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        ids = self.txt_tokenizer.encode(processed_text)
+        return [x + self.img_tokenizer.num_tokens for x in ids]
+    
+    def DecodeIds(self, ids):
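+        """Split `ids` into decoded text/command tokens (`ret`) and decoded
+        images (`ret_imgs`)."""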
+        ret, img_buffer, txt_buffer, ret_imgs = [], [], [], []
+        try:
+            for x in ids:
+                if self.num_tokens - len(self.raw_command_tokens) <= x:
+                    # command tokens
+                    token = self.raw_command_tokens[x - (self.num_tokens - len(self.raw_command_tokens))][0]
+                    if token.startswith('[EOI') and len(img_buffer) > 0:
+                        # dump image
+                        ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
+                        img_buffer = []
+                    if len(txt_buffer) > 0:
+                        # dump text
+                        ret.append(self.txt_tokenizer.decode(txt_buffer))
+                        txt_buffer = []
+                    ret.append(token)
+                elif x < self.img_tokenizer.num_tokens:
+                    img_buffer.append(x)
+                else:
+                    txt_buffer.append(x - self.img_tokenizer.num_tokens)
+            
+            if len(img_buffer) > 0:
+                # dump image
+                ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
+                img_buffer = []
+            if len(txt_buffer) > 0:
+                # dump text
+                ret.append(self.txt_tokenizer.decode(txt_buffer))
+                txt_buffer = []
+        except ValueError:
+            print('Value error in tokenization, skipping...')
+        return ret, ret_imgs
+
+    def wrap_code(self, code, idx=1):
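+        """Wrap an image code with its size token and `[BOI{idx}] ... [EOI{idx}]`."""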
+        s = sqrt_int(len(code))
+        prefix = {8:'[TINY]', 16:'[SMALL]', 32:'[BASE]', 64:'[BIG]'}[s]
+        boi = {1:'[BOI1]', 2: '[BOI2]', 3:'[BOI3]'}[idx]
+        eoi = {1:'[EOI1]', 2: '[EOI2]', 3:'[EOI3]'}[idx]
+    
+        if isinstance(code, list):
+            return [self.command_tokens[prefix], self.command_tokens[boi]] + \
+                code + [self.command_tokens[eoi]]
+        elif isinstance(code, np.ndarray):
+            return np.concatenate(
+                (
+                    np.array([self.command_tokens[prefix], self.command_tokens[boi]]),
+                    code,
+                    np.array([self.command_tokens[eoi]])
+                ),
+                axis=0
+            )
+        elif isinstance(code, torch.Tensor):
+            return torch.cat(
+                (
+                    # keep the wrapper tokens on the same device as `code`
+                    torch.tensor([self.command_tokens[prefix], self.command_tokens[boi]],
+                                 device=code.device),
+                    code,
+                    torch.tensor([self.command_tokens[eoi]], device=code.device)
+                )
+            )
+        else:
+            raise ValueError(f'cannot wrap code of type {type(code)}')
+
+    def parse_query(self, query, img_size=256):
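+        """Parse a space-separated query into a list of ids, where -1 marks
+        positions to generate. Supports plain text, command tokens,
+        `[MASK]*N` (N mask positions) and `[ImageN]path` (encode the image at
+        `path`, keeping the first N codes and masking the rest)."""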
+        text_buffer = []
+        ret = []
+        for part in query.split(' '):
+            if part in self.command_tokens:
+                if len(text_buffer) > 0:
+                    # dump text ids
+                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+                    text_buffer = []
+                if part == '[MASK]':
+                    ret.append(-1)
+                else:
+                    ret.append(self.command_tokens[part])
+            elif part.startswith('[MASK]*'): # special lang *N
+                c = int(part[7:])
+                assert c > 0
+                if len(text_buffer) > 0:
+                    # dump text ids
+                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+                    text_buffer = []
+                ret.extend([-1] * c)
+            elif part.startswith('[Image'): # [ImageN]path: keep the first N codes, mask the rest
+                c = part[6:]
+                assert len(c) > 0
+                num_codes, img_path = c.split(']')
+                if num_codes == '':
+                    num_codes = 1024
+                else:
+                    num_codes = int(num_codes)
+                
+                raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size)
+                img_codes = self.img_tokenizer.EncodeAsIds(raw_img) # [1, 32*32]
+                img_codes[0, num_codes:] = -1
+                img_codes = img_codes[0].tolist()
+                ret.extend(img_codes)
+            else:
+                text_buffer.append(part)
+
+        if len(text_buffer) > 0:
+            # dump text ids
+            ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+            text_buffer = []
+        return ret
+
diff --git a/tokenization/cogview/vqvae_tokenizer.py b/tokenization/cogview/vqvae_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..56ee251126fdcb901d4b88cea114342b1dfccdb7
--- /dev/null
+++ b/tokenization/cogview/vqvae_tokenizer.py
@@ -0,0 +1,86 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   vqvae_tokenizer.py
+@Time    :   2021/01/11 17:57:43
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+from vqvae import new_model, img2code, code2img
+from torchvision import transforms
+from PIL import Image
+
+def is_exp2(x):
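+    """True iff x is (numerically) an integer power of two."""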
+    t = math.log2(x)
+    return abs(t - int(t)) < 1e-4
+
+def sqrt_int(x):
+    r = int(math.sqrt(x) + 1e-4)
+    assert r * r == x
+    return r
+
+class VQVAETokenizer(object):
+    def __init__(self, 
+            model_path, 
+            device='cuda'
+        ):
+        ckpt = torch.load(model_path, map_location=torch.device(device))
+
+        model = new_model()
+
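+        # strip the 'module.' prefix left by (Distributed)DataParallel checkpoints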
+        if list(ckpt.keys())[0].startswith('module.'):
+            ckpt = {k[7:]: v for k, v in ckpt.items()}
+
+        model.load_state_dict(ckpt)
+        model = model.to(device)
+        model.eval()
+
+        self.model = model
+        self.device = device
+        self.image_tokens = model.quantize_t.n_embed
+        self.num_tokens = model.quantize_t.n_embed
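+        # per-channel mean/std, presumably the statistics of the VQVAE training data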
+        self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
+
+    def __len__(self):
+        return self.num_tokens
+
+    def EncodeAsIds(self, img, add_normalization=False):
+        assert len(img.shape) == 4 # [b, c, h, w]
+        if add_normalization:
+            img = self.tr_normalize(img)
+        return img2code(self.model, img)
+
+    def DecodeIds(self, code, shape=None):
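+        """Decode a flat image code (list or tensor) back into an image,
+        inferring a square grid when `shape` is None."""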
+        if shape is None:
+            if isinstance(code, list):
+                code = torch.tensor(code, device=self.device)
+            s = sqrt_int(len(code.view(-1)))
+            assert s * s == len(code.view(-1))
+            shape = (1, s, s)
+        code = code.view(shape)
+        out = code2img(self.model, code)
+        return out
+
+    def read_img(self, path, img_size=256):
+        tr = transforms.Compose([
+            transforms.Resize(img_size), 
+            transforms.CenterCrop(img_size),
+            transforms.ToTensor(),
+        ])
+        img = tr(Image.open(path))
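+        # drop the alpha channel of RGBA images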
+        if img.shape[0] == 4:
+            img = img[:-1]
+        img = self.tr_normalize(img)
+        img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
+        return img  
\ No newline at end of file