From 78b7ca5f0f17e8bacad40bb1ceccc893a2e19204 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Fri, 29 Oct 2021 19:14:57 +0000
Subject: [PATCH] pass cogview generate

---
 .gitignore                                    |   4 +-
 CHANGE_LOG.md                                 |  27 +++
 MANIFEST.in                                   |   3 +
 pretrained/cogview/placeholder => README.md   |   0
 SwissArmyTransformer/__init__.py              |   1 -
 SwissArmyTransformer/tokenization/__init__.py |   2 +-
 .../tokenization/cogview/vqvae/api.py         |  10 +-
 SwissArmyTransformer/training/__init__.py     |   2 +-
 SwissArmyTransformer/training/model_io.py     |   2 +-
 examples/cogview/inference_cogview.py         |   2 +-
 examples/cogview/inference_cogview_caps.py    |   4 +-
 .../cogview/scripts/text2image_cogview.sh     |   4 +-
 examples/cogview2/inference_cogview2.py       |   4 +-
 examples/glm/inference_glm.py                 |   4 +-
 inference_glm_old.py                          | 176 -----------------
 move_images.py                                |  28 ---
 move_weights.py                               | 185 ------------------
 pretrained/vqvae/placeholder                  |   0
 setup.py                                      |  33 ++++
 19 files changed, 83 insertions(+), 408 deletions(-)
 create mode 100644 MANIFEST.in
 rename pretrained/cogview/placeholder => README.md (100%)
 delete mode 100644 inference_glm_old.py
 delete mode 100644 move_images.py
 delete mode 100644 move_weights.py
 delete mode 100644 pretrained/vqvae/placeholder
 create mode 100644 setup.py

diff --git a/.gitignore b/.gitignore
index a87af0c..d129f41 100755
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,6 @@ input*.txt
 coco_scores/*
 checkpoints/
 *coco*
-runs
\ No newline at end of file
+runs
+dist/
+*.egg-info
\ No newline at end of file
diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md
index ad82b80..67cdf6f 100644
--- a/CHANGE_LOG.md
+++ b/CHANGE_LOG.md
@@ -1 +1,28 @@
 # 2021.10.29
+1. change `mixins` from `ModuleList` to `ModuleDict`
+2. return tokens and mems in `fill_sequence`, and mems becomes a tensor.
+3. `CachedAutoRegressiveMixin`
+## How to migrate old SAT ckpt to new version?
+Example:
+```python
+import torch
+old = torch.load('xxxxx/mp_rank_00_model_states.pt.old', map_location='cpu')
+
+# replace names, mixins index to keys
+oldm = old['module']
+for k in list(oldm.keys()):
+    if k.startswith('mixins.0'):
+        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
+    elif k.startswith('mixins.1'):
+        new_k = k.replace('mixins.1', 'mixins.attention_plus')
+    else:
+        continue
+    oldm[new_k] = oldm[k]
+    del oldm[k]
+# save to destination    
+torch.save(old, 'xxxxx/mp_rank_00_model_states.pt')
+
+```
+
+
+
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..db51f02
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include requirements.txt
+global-exclude __pycache__/*
+graft SwissArmyTransformer/tokenization/embed_assets
\ No newline at end of file
diff --git a/pretrained/cogview/placeholder b/README.md
similarity index 100%
rename from pretrained/cogview/placeholder
rename to README.md
diff --git a/SwissArmyTransformer/__init__.py b/SwissArmyTransformer/__init__.py
index 6e1974e..dd0d987 100644
--- a/SwissArmyTransformer/__init__.py
+++ b/SwissArmyTransformer/__init__.py
@@ -1,4 +1,3 @@
-__version__ = '0.1'
 from .arguments import get_args
 from .training import load_checkpoint, set_random_seed, initialize_distributed
 from .tokenization import get_tokenizer
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index 422c03e..64f9e36 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -29,7 +29,7 @@ def _export_vocab_size_to_args(args, original_num_tokens):
                  'tokens (new size: {})'.format(
         before, after - before, after))
     args.vocab_size = after
-    print_rank_0("prepare tokenizer done", flush=True)
+    print_rank_0("prepare tokenizer done")
     return tokenizer
 
 def get_tokenizer(args=None, outer_tokenizer=None):
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/api.py b/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
index 7d2fd03..060760a 100755
--- a/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
@@ -26,7 +26,7 @@ def new_module(config):
     if not "target" in config:
         raise KeyError("Expected key `target` to instantiate.")
     module, cls = config.get("target").rsplit(".", 1)
-    model = getattr(importlib.import_module(module, package=None), cls)(**config.get("params", dict()))
+    model = getattr(importlib.import_module(module, package=__package__), cls)(**config.get("params", dict()))
 
     device = config.get("device", "cpu")
     model = model.to(device)
@@ -45,7 +45,7 @@ def new_module(config):
 
 def load_decoder_default(device=0, path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"):
     # exp: load currently best decoder
-    target = "vqvae.vqvae_diffusion.Decoder"
+    target = ".vqvae_diffusion.Decoder"
     params = {
         "double_z": False,
         "z_channels": 256,
@@ -100,7 +100,7 @@ def load_model_default(device=0,
     }
 
     config = {
-        'target': "vqvae.vqvae_zc.VQVAE",
+        'target': ".vqvae_zc.VQVAE",
         'params': params,
         'ckpt': path,
         'device': device
@@ -116,7 +116,7 @@ def test_decode(configs, testcase, device=0, output_path=None):
         output_path = os.path.join("sample", f"{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg")
 
     quantize_config = {
-        "target": "vqvae.vqvae_zc.Quantize",
+        "target": ".vqvae_zc.Quantize",
         "params": {
             "dim": 256,
             "n_embed": 8192,
@@ -149,7 +149,7 @@ def test_decode_default(device=0):
     # testing 3 decoders: original/l1+ms-ssim/l1+ms-ssim+perceptual
     configs = [
         {
-            "target": "vqvae.vqvae_zc.Decoder",
+            "target": ".vqvae_zc.Decoder",
             "params": {
                 "in_channel": 256, 
                 "out_channel": 3,
diff --git a/SwissArmyTransformer/training/__init__.py b/SwissArmyTransformer/training/__init__.py
index dc0337f..4d462e6 100644
--- a/SwissArmyTransformer/training/__init__.py
+++ b/SwissArmyTransformer/training/__init__.py
@@ -1,2 +1,2 @@
-from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer
+from .deepspeed_training import initialize_distributed, set_random_seed
 from .model_io import load_checkpoint
\ No newline at end of file
diff --git a/SwissArmyTransformer/training/model_io.py b/SwissArmyTransformer/training/model_io.py
index 92becb9..d66e40e 100644
--- a/SwissArmyTransformer/training/model_io.py
+++ b/SwissArmyTransformer/training/model_io.py
@@ -14,7 +14,7 @@ import random
 import torch
 import numpy as np
 
-import SwissArmyTransformer.mpu
+from SwissArmyTransformer import mpu
 from .utils import print_rank_0
 
 
diff --git a/examples/cogview/inference_cogview.py b/examples/cogview/inference_cogview.py
index 8d2a555..f546e37 100644
--- a/examples/cogview/inference_cogview.py
+++ b/examples/cogview/inference_cogview.py
@@ -62,7 +62,7 @@ def main(args):
                     batch_size=min(args.batch_size, mbz),
                     strategy=strategy,
                     log_attention_weights=log_attention_weights
-                    )
+                    )[0]
                 )
         output_tokens = torch.cat(output_list, dim=0)
         # decoding
diff --git a/examples/cogview/inference_cogview_caps.py b/examples/cogview/inference_cogview_caps.py
index 7b2e4e7..db6a35e 100644
--- a/examples/cogview/inference_cogview_caps.py
+++ b/examples/cogview/inference_cogview_caps.py
@@ -16,13 +16,13 @@ import argparse
 
 from arguments import get_args
 from model.base_model import BaseModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
 from generation.autoregressive_sampling import get_masks_and_position_ids
 from generation.utils import timed_name, save_multiple_images, generate_continually
 
 def main(args):
     initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model 
     model = BaseModel(args)
     if args.fp16:
diff --git a/examples/cogview/scripts/text2image_cogview.sh b/examples/cogview/scripts/text2image_cogview.sh
index bcb1ecd..9bb2213 100755
--- a/examples/cogview/scripts/text2image_cogview.sh
+++ b/examples/cogview/scripts/text2image_cogview.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=pretrained/cogview/cogview-base
+CHECKPOINT_PATH=/workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -17,7 +17,7 @@ script_dir=$(dirname $script_path)
 
 MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path /workspace/dm/SwissArmyTransformer/pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
diff --git a/examples/cogview2/inference_cogview2.py b/examples/cogview2/inference_cogview2.py
index c69e589..b3fb7e0 100644
--- a/examples/cogview2/inference_cogview2.py
+++ b/examples/cogview2/inference_cogview2.py
@@ -19,7 +19,7 @@ from torchvision import transforms
 from arguments import get_args
 from model.cached_autoregressive_model import CachedAutoregressiveModel
 from model.cuda2d_model import Cuda2dModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy, IterativeEntfilterStrategy
 from generation.autoregressive_sampling import filling_sequence
@@ -28,7 +28,7 @@ from generation.utils import timed_name, save_multiple_images, generate_continua
 
 def main(args):
     initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model 
     model = Cuda2dModel(args)
     if args.fp16:
diff --git a/examples/glm/inference_glm.py b/examples/glm/inference_glm.py
index 66c33f4..dd7e231 100644
--- a/examples/glm/inference_glm.py
+++ b/examples/glm/inference_glm.py
@@ -22,7 +22,7 @@ from functools import partial
 from arguments import get_args
 from model.glm_model import GLMModel
 from model.cached_autoregressive_model import CachedAutoregressiveMixin
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
 from generation.autoregressive_sampling import filling_sequence
 from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 from generation.utils import timed_name, generate_continually
@@ -48,7 +48,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 def main(args):
     args.do_train = False
     initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model 
     model = GLMModel(args)
     model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
diff --git a/inference_glm_old.py b/inference_glm_old.py
deleted file mode 100644
index 792efd5..0000000
--- a/inference_glm_old.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   inference_cogview.py
-@Time    :   2021/10/09 19:41:58
-@Author  :   Ming Ding
-@Contact :   dm18@mail.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import random
-import time
-from datetime import datetime
-import torch
-import torch.nn.functional as F
-
-import mpu
-from arguments import get_args
-from model.glm_model import GLMModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
-from generation.glm_sampling import filling_sequence_glm
-from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
-
-
-def read_context(tokenizer, args, output=None):
-    terminate_runs, skip_run = 0, 0
-    if mpu.get_model_parallel_rank() == 0:
-        while True:
-            raw_text = input("\nContext prompt (stop to exit) >>> ")
-            if not raw_text:
-                print('Prompt should not be empty!')
-                continue
-            if raw_text == "stop":
-                terminate_runs = 1
-                break
-            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
-            if args.block_lm and 'MASK]' not in raw_text:
-                raw_text += ' ' + generation_mask
-            if output is not None:
-                output.write(raw_text)
-            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
-            if args.block_lm:
-                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
-                if not raw_text.endswith('MASK]'):
-                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
-            context_length = len(context_tokens)
-
-            if context_length >= args.max_sequence_length:
-                print("\nContext length", context_length,
-                      "\nPlease give smaller context than the window length!")
-                continue
-            break
-    else:
-        context_length = 0
-
-    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    terminate_runs = terminate_runs_tensor[0].item()
-
-    if terminate_runs == 1:
-        return terminate_runs, None, None, None
-
-    context_length_tensor = torch.cuda.LongTensor([context_length])
-
-    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    context_length = context_length_tensor[0].item()
-    if mpu.get_model_parallel_rank() == 0:
-        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    else:
-        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
-    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    if mpu.get_model_parallel_rank() != 0:
-        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
-    return terminate_runs, raw_text, context_tokens_tensor, context_length
-
-
-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(1, -1).contiguous()
-    tokens = tokens.to('cuda')
-
-    # Get the masks and postition ids.
-    if args.block_lm:
-        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
-        if args.fp16:
-            attention_mask = attention_mask.half()
-        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
-        if not args.no_block_position:
-            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
-        position_ids = position_ids.unsqueeze(0)
-    else:
-        raise NotImplementedError
-
-    return tokens, attention_mask, position_ids
-
-
-def generate_samples(model, tokenizer, args):
-    model.eval()
-    output_path = "./samples"
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
-    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
-    with torch.no_grad(), open(output_path, "w") as output:
-        while True:
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
-            if terminate_runs == 1:
-                return
-            start_time = time.time()
-            if args.block_lm:
-                mems = []
-                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
-                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
-                mask_positions = []
-                for token in mask_tokens:
-                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-                mask_positions.sort()
-                if args.no_block_position:
-                    for mask_position in mask_positions:
-                        position_ids[0, mask_position + 1:] += args.out_seq_length
-                _, *mems = model(tokens, position_ids, attention_mask, *mems)
-                for mask_position in mask_positions:
-                    if args.no_block_position:
-                        position = position_ids[0, mask_position].item()
-                    else:
-                        position = mask_position
-                    if args.num_beams > 1:
-                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
-                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
-                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
-                                                      min_tgt_length=args.min_tgt_length)
-                    else:
-                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
-                                                end_tokens=end_tokens)
-                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
-                                                            end_tokens=end_tokens)
-                    tokens = torch.cat((tokens, new_tokens), dim=1)
-            output_tokens_list = tokens.view(-1).contiguous()
-            if mpu.get_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
-                print("\nContext:", raw_text, flush=True)
-                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
-                trim_decode_tokens = decode_tokens
-                print("\nGLM:", trim_decode_tokens, flush=True)
-                output.write(trim_decode_tokens + "\n")
-
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-
-
-def main(args):
-    initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
-    # build model
-    model = GLMModel(args)
-    if args.fp16:
-        model = model.half()
-    model = model.to(args.device)
-    load_checkpoint(model, args)
-    set_random_seed(args.seed)
-    model.eval()
-    generate_samples(model, tokenizer, args)
-
-
-if __name__ == "__main__":
-    args = get_args()
-
-    with torch.no_grad():
-        main(args)
diff --git a/move_images.py b/move_images.py
deleted file mode 100644
index 1cdd37d..0000000
--- a/move_images.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# %%
-coco_30k = '/workspace/dm/SwissArmyTransformer/coco30k.txt'
-with open(coco_30k, 'r') as fin:
-    lines = fin.readlines()
-    
-import os
-from posixpath import join
-import shutil
-prefix0 = '/workspace/dm/SwissArmyTransformer/coco_samples'
-prefix1 = '/dataset/fd5061f6/mingding/SwissArmyTransformer/coco_samples'
-cnt = 0 
-with open('coco_select.txt', 'w') as fout:
-    for i, line in enumerate(lines):
-        _id, text = line.strip().split('\t')
-        if i % 200 == 0:
-            print(i, cnt)
-        src = os.path.join(prefix1, _id)
-        if not os.path.exists(src):
-            src = os.path.join(prefix0, _id)
-        assert os.path.exists(src), _id
-        fout.write(
-            '\t'.join([text] + [
-                os.path.join(src, f'{i}.jpg')
-                for i in range(60)
-                ]) + '\n'
-        )
-                
-    
\ No newline at end of file
diff --git a/move_weights.py b/move_weights.py
deleted file mode 100644
index 2d73f94..0000000
--- a/move_weights.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# %%
-# import torch
-# old = torch.load('pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt.sat1', map_location='cpu')
-
-# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-# del old['module']['word_embeddings.weight']
-
-# from model.base_model import BaseModel
-# import argparse
-# import os
-# args = argparse.Namespace(
-#     num_layers=48,
-#     vocab_size=58240,
-#     hidden_size=2560,
-#     num_attention_heads=40,
-#     max_sequence_length=1089,
-#     hidden_dropout=0.1,
-#     attention_dropout=0.1,
-#     checkpoint_activations=True,
-#     checkpoint_num_layers=1,
-#     sandwich_ln=True,
-#     model_parallel_size=1,
-#     world_size=1,
-#     rank=0
-#     )
-# init_method = 'tcp://'
-# master_ip = os.getenv('MASTER_ADDR', 'localhost')
-# master_port = os.getenv('MASTER_PORT', '6000')
-# init_method += master_ip + ':' + master_port
-# torch.distributed.init_process_group(
-#         backend='nccl',
-#         world_size=args.world_size, rank=args.rank,init_method=init_method)
-# import mpu
-#     # Set the model-parallel / data-parallel communicators.
-# mpu.initialize_model_parallel(args.model_parallel_size)
-# print('bg')
-# model = BaseModel(args)
-# # %%
-# missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
-# torch.save(old, 'pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt')
-
-
-
-# %%
-import torch
-old = torch.load('/dataset/fd5061f6/english_data/checkpoints/blocklm-10b-1024/126000/mp_rank_00_model_states.pt', map_location='cpu')
-# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-# del old['module']['word_embeddings.weight']
-#%%
-import torch
-
-from model.cuda2d_model import Cuda2dModel
-import argparse
-import os
-args = argparse.Namespace(
-    num_layers=48,
-    vocab_size=58240,
-    hidden_size=2560,
-    num_attention_heads=40,
-    max_sequence_length=1089,
-    hidden_dropout=0.1,
-    attention_dropout=0.1,
-    checkpoint_activations=True,
-    checkpoint_num_layers=1,
-    sandwich_ln=True,
-    model_parallel_size=1,
-    world_size=1,
-    rank=0,
-    new_sequence_length=1089+4096,
-    layout='0,64,1088,5184',
-    kernel_size=9,
-    kernel_size2=7
-    )
-
-init_method = 'tcp://'
-master_ip = os.getenv('MASTER_ADDR', 'localhost')
-master_port = os.getenv('MASTER_PORT', '6000')
-init_method += master_ip + ':' + master_port
-torch.distributed.init_process_group(
-        backend='nccl',
-        world_size=args.world_size, rank=args.rank,init_method=init_method)
-import mpu
-    # Set the model-parallel / data-parallel communicators.
-mpu.initialize_model_parallel(args.model_parallel_size)
-print('bg')
-#%%
-model = Cuda2dModel(args)
-
-#%%
-oldm = old['module']
-for k in list(oldm.keys()):
-    if k.startswith('mixins.0'):
-        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
-    elif k.startswith('mixins.1'):
-        new_k = k.replace('mixins.1', 'mixins.attention_plus')
-    else:
-        continue
-    oldm[new_k] = oldm[k]
-    del oldm[k]
-
-#%%
-old['module']['mixins.0.position_embeddings.weight'] = old['module']['transformer.position_embeddings_plus.weight']
-del old['module']['transformer.position_embeddings_plus.weight']
-
-for i in range(48):
-    old['module'][f'mixins.1.query_key_value.{i}.weight'] = \
-        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
-    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
-    old['module'][f'mixins.1.query_key_value.{i}.bias'] = \
-        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
-    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
-    old['module'][f'mixins.1.dense.{i}.weight'] = \
-        old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
-    del old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
-    old['module'][f'mixins.1.dense.{i}.bias'] = \
-        old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
-    del old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
-# %%
-missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
-
-# %%
-torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt')
-# # %%
-# import torch
-# old = torch.load("/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", map_location='cpu')
-
-# # %%
-# from collections import OrderedDict
-# new_ckpt = OrderedDict()
-# for k,v in old['state_dict'].items():
-#     new_ckpt[k] = v.detach()
-# torch.save(new_ckpt, 'pretrained/vqvae/l1+ms-ssim+revd_percep.pt')
-# # %%
-
-# %%
-
-old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-del old['module']['word_embeddings.weight']
-#%%
-import torch
-
-from model.glm_model import GLMModel
-import argparse
-import os
-args = argparse.Namespace(
-    num_layers=48,
-    vocab_size=50304,
-    hidden_size=4096,
-    num_attention_heads=64,
-    max_sequence_length=1025,
-    hidden_dropout=0.1,
-    attention_dropout=0.1,
-    checkpoint_activations=True,
-    checkpoint_num_layers=1,
-    sandwich_ln=False,
-    model_parallel_size=1,
-    world_size=1,
-    rank=0
-    )
-
-init_method = 'tcp://'
-master_ip = os.getenv('MASTER_ADDR', 'localhost')
-master_port = os.getenv('MASTER_PORT', '6000')
-init_method += master_ip + ':' + master_port
-torch.distributed.init_process_group(
-        backend='nccl',
-        world_size=args.world_size, rank=args.rank,init_method=init_method)
-import mpu
-    # Set the model-parallel / data-parallel communicators.
-mpu.initialize_model_parallel(args.model_parallel_size)
-print('bg')
-# %%
-model = GLMModel(args)
-# %%
-old['module']['mixins.block_position_embedding.block_position_embeddings.weight'] = old['module']['transformer.block_position_embeddings.weight']
-del old['module']['transformer.block_position_embeddings.weight']
-# %%
-missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=True)
-
-# %%
-import os
-os.makedirs('pretrained/glm/glm-en-10b/250000', exist_ok=True)
-torch.save(old, 'pretrained/glm/glm-en-10b/250000/mp_rank_00_model_states.pt')
-
-# %%
diff --git a/pretrained/vqvae/placeholder b/pretrained/vqvae/placeholder
deleted file mode 100644
index e69de29..0000000
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..84ec277
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,33 @@
+
+# Copyright (c) Ming Ding, et al. in KEG, Tsinghua University.
+#
+# LICENSE file in the root directory of this source tree.
+
+import json
+import sys
+import os
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+
+def _requirements():
+    return Path("requirements.txt").read_text()
+
+setup(
+    name="SwissArmyTransformer",
+    version=0.1,
+    description="A transformer-based framework with finetuning as the first class citizen.",
+    long_description=Path("README.md").read_text(),
+    long_description_content_type="text/markdown",
+    install_requires=_requirements(),
+    entry_points={},
+    packages=find_packages(),
+    url="https://github.com/THUDM/SwissArmyTransformer",
+    author="Ming Ding, et al.",
+    author_email="dm_thu@qq.com",
+    scripts={},
+    include_package_data=True,
+    python_requires=">=3.5",
+    license="Apache 2.0 license"
+)
\ No newline at end of file
-- 
GitLab