From 78b7ca5f0f17e8bacad40bb1ceccc893a2e19204 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Fri, 29 Oct 2021 19:14:57 +0000
Subject: [PATCH] pass cogview generate

---
 .gitignore                                    |   4 +-
 CHANGE_LOG.md                                 |  27 +++
 MANIFEST.in                                   |   3 +
 pretrained/cogview/placeholder => README.md   |   0
 SwissArmyTransformer/__init__.py              |   1 -
 SwissArmyTransformer/tokenization/__init__.py |   2 +-
 .../tokenization/cogview/vqvae/api.py         |  10 +-
 SwissArmyTransformer/training/__init__.py     |   2 +-
 SwissArmyTransformer/training/model_io.py     |   2 +-
 examples/cogview/inference_cogview.py         |   2 +-
 examples/cogview/inference_cogview_caps.py    |   4 +-
 .../cogview/scripts/text2image_cogview.sh     |   4 +-
 examples/cogview2/inference_cogview2.py       |   4 +-
 examples/glm/inference_glm.py                 |   4 +-
 inference_glm_old.py                          | 176 -----------------
 move_images.py                                |  28 ---
 move_weights.py                               | 185 ------------------
 pretrained/vqvae/placeholder                  |   0
 setup.py                                      |  33 ++++
 19 files changed, 83 insertions(+), 408 deletions(-)
 create mode 100644 MANIFEST.in
 rename pretrained/cogview/placeholder => README.md (100%)
 delete mode 100644 inference_glm_old.py
 delete mode 100644 move_images.py
 delete mode 100644 move_weights.py
 delete mode 100644 pretrained/vqvae/placeholder
 create mode 100644 setup.py

diff --git a/.gitignore b/.gitignore
index a87af0c..d129f41 100755
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,6 @@ input*.txt
 coco_scores/*
 checkpoints/
 *coco*
-runs
\ No newline at end of file
+runs
+dist/
+*.egg-info
\ No newline at end of file
diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md
index ad82b80..67cdf6f 100644
--- a/CHANGE_LOG.md
+++ b/CHANGE_LOG.md
@@ -1 +1,28 @@
 # 2021.10.29
+1. Changed `mixins` from a `ModuleList` to a `ModuleDict`.
+2. `filling_sequence` now returns both tokens and mems, and `mems` becomes a tensor.
+3. Added `CachedAutoregressiveMixin`.
+## How to migrate an old SAT ckpt to the new version?
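+
+Background: because `mixins` is now a `ModuleDict` (change 1), mixin parameters live under
+string keys such as `mixins.extra_position_embedding` and `mixins.attention_plus`, while
+old checkpoints still store them under numeric indices (`mixins.0.*`, `mixins.1.*`). The
+keys in a saved state dict therefore have to be renamed before loading, as in the
+example below.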
+Example:
+```python
+import torch
+old = torch.load('xxxxx/mp_rank_00_model_states.pt.old', map_location='cpu')
+
+# replace names, mixins index to keys
+oldm = old['module']
+for k in list(oldm.keys()):
+    if k.startswith('mixins.0'):
+        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
+    elif k.startswith('mixins.1'):
+        new_k = k.replace('mixins.1', 'mixins.attention_plus')
+    else:
+        continue
+    oldm[new_k] = oldm[k]
+    del oldm[k]
+# save to destination
+torch.save(old, 'xxxxx/mp_rank_00_model_states.pt')
+
+```
+
+
+
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..db51f02
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include requirements.txt
+global-exclude __pycache__/*
+graft SwissArmyTransformer/tokenization/embed_assets
\ No newline at end of file
diff --git a/pretrained/cogview/placeholder b/README.md
similarity index 100%
rename from pretrained/cogview/placeholder
rename to README.md
diff --git a/SwissArmyTransformer/__init__.py b/SwissArmyTransformer/__init__.py
index 6e1974e..dd0d987 100644
--- a/SwissArmyTransformer/__init__.py
+++ b/SwissArmyTransformer/__init__.py
@@ -1,4 +1,3 @@
-__version__ = '0.1'
 from .arguments import get_args
 from .training import load_checkpoint, set_random_seed, initialize_distributed
 from .tokenization import get_tokenizer
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index 422c03e..64f9e36 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -29,7 +29,7 @@ def _export_vocab_size_to_args(args, original_num_tokens):
             'tokens (new size: {})'.format(
                 before, after - before, after))
     args.vocab_size = after
-    print_rank_0("prepare tokenizer done", flush=True)
+    print_rank_0("prepare tokenizer done")
     return tokenizer
 
 def get_tokenizer(args=None, outer_tokenizer=None):
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/api.py b/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
index 7d2fd03..060760a 100755
--- a/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/api.py
@@ -26,7 +26,7 @@ def new_module(config):
     if not "target" in config:
         raise KeyError("Expected key `target` to instantiate.")
     module, cls = config.get("target").rsplit(".", 1)
-    model = getattr(importlib.import_module(module, package=None), cls)(**config.get("params", dict()))
+    model = getattr(importlib.import_module(module, package=__package__), cls)(**config.get("params", dict()))
 
     device = config.get("device", "cpu")
     model = model.to(device)
@@ -45,7 +45,7 @@ def new_module(config):
 
 def load_decoder_default(device=0, path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"):
     # exp: load currently best decoder
-    target = "vqvae.vqvae_diffusion.Decoder"
+    target = ".vqvae_diffusion.Decoder"
     params = {
         "double_z": False,
         "z_channels": 256,
@@ -100,7 +100,7 @@ def load_model_default(device=0,
     }
     config = {
-        'target': "vqvae.vqvae_zc.VQVAE",
+        'target': ".vqvae_zc.VQVAE",
         'params': params,
         'ckpt': path,
         'device': device
@@ -116,7 +116,7 @@ def test_decode(configs, testcase, device=0, output_path=None):
         output_path = os.path.join("sample", f"{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg")
 
     quantize_config = {
-        "target": "vqvae.vqvae_zc.Quantize",
+        "target": ".vqvae_zc.Quantize",
         "params": {
             "dim": 256,
             "n_embed": 8192,
@@ -149,7 +149,7 @@ def test_decode_default(device=0):
     # testing 3 decoders: original/l1+ms-ssim/l1+ms-ssim+perceptual
     configs = [
         {
-            "target": "vqvae.vqvae_zc.Decoder",
"vqvae.vqvae_zc.Decoder", + "target": ".vqvae_zc.Decoder", "params": { "in_channel": 256, "out_channel": 3, diff --git a/SwissArmyTransformer/training/__init__.py b/SwissArmyTransformer/training/__init__.py index dc0337f..4d462e6 100644 --- a/SwissArmyTransformer/training/__init__.py +++ b/SwissArmyTransformer/training/__init__.py @@ -1,2 +1,2 @@ -from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer +from .deepspeed_training import initialize_distributed, set_random_seed from .model_io import load_checkpoint \ No newline at end of file diff --git a/SwissArmyTransformer/training/model_io.py b/SwissArmyTransformer/training/model_io.py index 92becb9..d66e40e 100644 --- a/SwissArmyTransformer/training/model_io.py +++ b/SwissArmyTransformer/training/model_io.py @@ -14,7 +14,7 @@ import random import torch import numpy as np -import SwissArmyTransformer.mpu +from SwissArmyTransformer import mpu from .utils import print_rank_0 diff --git a/examples/cogview/inference_cogview.py b/examples/cogview/inference_cogview.py index 8d2a555..f546e37 100644 --- a/examples/cogview/inference_cogview.py +++ b/examples/cogview/inference_cogview.py @@ -62,7 +62,7 @@ def main(args): batch_size=min(args.batch_size, mbz), strategy=strategy, log_attention_weights=log_attention_weights - ) + )[0] ) output_tokens = torch.cat(output_list, dim=0) # decoding diff --git a/examples/cogview/inference_cogview_caps.py b/examples/cogview/inference_cogview_caps.py index 7b2e4e7..db6a35e 100644 --- a/examples/cogview/inference_cogview_caps.py +++ b/examples/cogview/inference_cogview_caps.py @@ -16,13 +16,13 @@ import argparse from arguments import get_args from model.base_model import BaseModel -from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer +from training import load_checkpoint, initialize_distributed, set_random_seed from generation.autoregressive_sampling import get_masks_and_position_ids from generation.utils import timed_name, save_multiple_images, generate_continually def main(args): initialize_distributed(args) - tokenizer = prepare_tokenizer(args) + tokenizer = get_tokenizer(args) # build model model = BaseModel(args) if args.fp16: diff --git a/examples/cogview/scripts/text2image_cogview.sh b/examples/cogview/scripts/text2image_cogview.sh index bcb1ecd..9bb2213 100755 --- a/examples/cogview/scripts/text2image_cogview.sh +++ b/examples/cogview/scripts/text2image_cogview.sh @@ -1,6 +1,6 @@ #!/bin/bash -CHECKPOINT_PATH=pretrained/cogview/cogview-base +CHECKPOINT_PATH=/workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base NLAYERS=48 NHIDDEN=2560 NATT=40 @@ -17,7 +17,7 @@ script_dir=$(dirname $script_path) MASTER_PORT=${MASTER_PORT} python inference_cogview.py \ --tokenizer-type cogview \ - --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \ + --img-tokenizer-path /workspace/dm/SwissArmyTransformer/pretrained/vqvae/l1+ms-ssim+revd_percep.pt \ --mode inference \ --distributed-backend nccl \ --max-sequence-length 1089 \ diff --git a/examples/cogview2/inference_cogview2.py b/examples/cogview2/inference_cogview2.py index c69e589..b3fb7e0 100644 --- a/examples/cogview2/inference_cogview2.py +++ b/examples/cogview2/inference_cogview2.py @@ -19,7 +19,7 @@ from torchvision import transforms from arguments import get_args from model.cached_autoregressive_model import CachedAutoregressiveModel from model.cuda2d_model import Cuda2dModel -from training import load_checkpoint, initialize_distributed, set_random_seed, 
+from training import load_checkpoint, initialize_distributed, set_random_seed
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy, IterativeEntfilterStrategy
 from generation.autoregressive_sampling import filling_sequence
@@ -28,7 +28,7 @@ from generation.utils import timed_name, save_multiple_images, generate_continua
 
 def main(args):
     initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model
     model = Cuda2dModel(args)
     if args.fp16:
diff --git a/examples/glm/inference_glm.py b/examples/glm/inference_glm.py
index 66c33f4..dd7e231 100644
--- a/examples/glm/inference_glm.py
+++ b/examples/glm/inference_glm.py
@@ -22,7 +22,7 @@ from functools import partial
 from arguments import get_args
 from model.glm_model import GLMModel
 from model.cached_autoregressive_model import CachedAutoregressiveMixin
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
 from generation.autoregressive_sampling import filling_sequence
 from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 from generation.utils import timed_name, generate_continually
@@ -48,7 +48,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 def main(args):
     args.do_train = False
     initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model
     model = GLMModel(args)
     model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
diff --git a/inference_glm_old.py b/inference_glm_old.py
deleted file mode 100644
index 792efd5..0000000
--- a/inference_glm_old.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File : inference_cogview.py
-@Time : 2021/10/09 19:41:58
-@Author : Ming Ding
-@Contact : dm18@mail.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import random
-import time
-from datetime import datetime
-import torch
-import torch.nn.functional as F
-
-import mpu
-from arguments import get_args
-from model.glm_model import GLMModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
-from generation.glm_sampling import filling_sequence_glm
-from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
-
-
-def read_context(tokenizer, args, output=None):
-    terminate_runs, skip_run = 0, 0
-    if mpu.get_model_parallel_rank() == 0:
-        while True:
-            raw_text = input("\nContext prompt (stop to exit) >>> ")
-            if not raw_text:
-                print('Prompt should not be empty!')
-                continue
-            if raw_text == "stop":
-                terminate_runs = 1
-                break
-            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
-            if args.block_lm and 'MASK]' not in raw_text:
-                raw_text += ' ' + generation_mask
-            if output is not None:
-                output.write(raw_text)
-            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
-            if args.block_lm:
-                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
-                if not raw_text.endswith('MASK]'):
-                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
-            context_length = len(context_tokens)
-
-            if context_length >= args.max_sequence_length:
-                print("\nContext length", context_length,
-                      "\nPlease give smaller context than the window length!")
-                continue
-            break
-    else:
-        context_length = 0
-
-    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    terminate_runs = terminate_runs_tensor[0].item()
-
-    if terminate_runs == 1:
-        return terminate_runs, None, None, None
-
-    context_length_tensor = torch.cuda.LongTensor([context_length])
-
-    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    context_length = context_length_tensor[0].item()
-    if mpu.get_model_parallel_rank() == 0:
-        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    else:
-        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
-    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    if mpu.get_model_parallel_rank() != 0:
-        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
-    return terminate_runs, raw_text, context_tokens_tensor, context_length
-
-
-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(1, -1).contiguous()
-    tokens = tokens.to('cuda')
-
-    # Get the masks and postition ids.
-    if args.block_lm:
-        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
-        if args.fp16:
-            attention_mask = attention_mask.half()
-        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
-        if not args.no_block_position:
-            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
-        position_ids = position_ids.unsqueeze(0)
-    else:
-        raise NotImplementedError
-
-    return tokens, attention_mask, position_ids
-
-
-def generate_samples(model, tokenizer, args):
-    model.eval()
-    output_path = "./samples"
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
-    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
-    with torch.no_grad(), open(output_path, "w") as output:
-        while True:
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
-            if terminate_runs == 1:
-                return
-            start_time = time.time()
-            if args.block_lm:
-                mems = []
-                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
-                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
-                mask_positions = []
-                for token in mask_tokens:
-                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-                mask_positions.sort()
-                if args.no_block_position:
-                    for mask_position in mask_positions:
-                        position_ids[0, mask_position + 1:] += args.out_seq_length
-                _, *mems = model(tokens, position_ids, attention_mask, *mems)
-                for mask_position in mask_positions:
-                    if args.no_block_position:
-                        position = position_ids[0, mask_position].item()
-                    else:
-                        position = mask_position
-                    if args.num_beams > 1:
-                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
-                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
-                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
-                                                      min_tgt_length=args.min_tgt_length)
-                    else:
-                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
-                                                end_tokens=end_tokens)
-                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
-                                                            end_tokens=end_tokens)
-                    tokens = torch.cat((tokens, new_tokens), dim=1)
-            output_tokens_list = tokens.view(-1).contiguous()
-            if mpu.get_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
-                print("\nContext:", raw_text, flush=True)
-                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
-                trim_decode_tokens = decode_tokens
-                print("\nGLM:", trim_decode_tokens, flush=True)
-                output.write(trim_decode_tokens + "\n")
-
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-
-
-def main(args):
-    initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
-    # build model
-    model = GLMModel(args)
-    if args.fp16:
-        model = model.half()
-    model = model.to(args.device)
-    load_checkpoint(model, args)
-    set_random_seed(args.seed)
-    model.eval()
-    generate_samples(model, tokenizer, args)
-
-
-if __name__ == "__main__":
-    args = get_args()
-
-    with torch.no_grad():
-        main(args)
diff --git a/move_images.py b/move_images.py
deleted file mode 100644
index 1cdd37d..0000000
--- a/move_images.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# %%
-coco_30k = '/workspace/dm/SwissArmyTransformer/coco30k.txt'
-with open(coco_30k, 'r') as fin:
-    lines = fin.readlines()
-
-import os
-from posixpath import join
-import shutil
-prefix0 = '/workspace/dm/SwissArmyTransformer/coco_samples'
-prefix1 = '/dataset/fd5061f6/mingding/SwissArmyTransformer/coco_samples'
-cnt = 0
-with open('coco_select.txt', 'w') as fout:
-    for i, line in enumerate(lines):
-        _id, text = line.strip().split('\t')
-        if i % 200 == 0:
-            print(i, cnt)
-        src = os.path.join(prefix1, _id)
-        if not os.path.exists(src):
-            src = os.path.join(prefix0, _id)
-        assert os.path.exists(src), _id
-        fout.write(
-            '\t'.join([text] + [
-                os.path.join(src, f'{i}.jpg')
-                for i in range(60)
-            ]) + '\n'
-        )
-
-
\ No newline at end of file
diff --git a/move_weights.py b/move_weights.py
deleted file mode 100644
index 2d73f94..0000000
--- a/move_weights.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# %%
-# import torch
-# old = torch.load('pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt.sat1', map_location='cpu')
-
-# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-# del old['module']['word_embeddings.weight']
-
-# from model.base_model import BaseModel
-# import argparse
-# import os
-# args = argparse.Namespace(
-#     num_layers=48,
-#     vocab_size=58240,
-#     hidden_size=2560,
-#     num_attention_heads=40,
-#     max_sequence_length=1089,
-#     hidden_dropout=0.1,
-#     attention_dropout=0.1,
-#     checkpoint_activations=True,
-#     checkpoint_num_layers=1,
-#     sandwich_ln=True,
-#     model_parallel_size=1,
-#     world_size=1,
-#     rank=0
-# )
-# init_method = 'tcp://'
-# master_ip = os.getenv('MASTER_ADDR', 'localhost')
-# master_port = os.getenv('MASTER_PORT', '6000')
-# init_method += master_ip + ':' + master_port
-# torch.distributed.init_process_group(
-#     backend='nccl',
-#     world_size=args.world_size, rank=args.rank,init_method=init_method)
-# import mpu
-# # Set the model-parallel / data-parallel communicators.
-# mpu.initialize_model_parallel(args.model_parallel_size)
-# print('bg')
-# model = BaseModel(args)
-# # %%
-# missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
-# torch.save(old, 'pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt')
-
-
-
-# %%
-import torch
-old = torch.load('/dataset/fd5061f6/english_data/checkpoints/blocklm-10b-1024/126000/mp_rank_00_model_states.pt', map_location='cpu')
-# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-# del old['module']['word_embeddings.weight']
-#%%
-import torch
-
-from model.cuda2d_model import Cuda2dModel
-import argparse
-import os
-args = argparse.Namespace(
-    num_layers=48,
-    vocab_size=58240,
-    hidden_size=2560,
-    num_attention_heads=40,
-    max_sequence_length=1089,
-    hidden_dropout=0.1,
-    attention_dropout=0.1,
-    checkpoint_activations=True,
-    checkpoint_num_layers=1,
-    sandwich_ln=True,
-    model_parallel_size=1,
-    world_size=1,
-    rank=0,
-    new_sequence_length=1089+4096,
-    layout='0,64,1088,5184',
-    kernel_size=9,
-    kernel_size2=7
-    )
-
-init_method = 'tcp://'
-master_ip = os.getenv('MASTER_ADDR', 'localhost')
-master_port = os.getenv('MASTER_PORT', '6000')
-init_method += master_ip + ':' + master_port
-torch.distributed.init_process_group(
-    backend='nccl',
-    world_size=args.world_size, rank=args.rank,init_method=init_method)
-import mpu
- # Set the model-parallel / data-parallel communicators.
-mpu.initialize_model_parallel(args.model_parallel_size)
-print('bg')
-#%%
-model = Cuda2dModel(args)
-
-#%%
-oldm = old['module']
-for k in list(oldm.keys()):
-    if k.startswith('mixins.0'):
-        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
-    elif k.startswith('mixins.1'):
-        new_k = k.replace('mixins.1', 'mixins.attention_plus')
-    else:
-        continue
-    oldm[new_k] = oldm[k]
-    del oldm[k]
-
-#%%
-old['module']['mixins.0.position_embeddings.weight'] = old['module']['transformer.position_embeddings_plus.weight']
-del old['module']['transformer.position_embeddings_plus.weight']
-
-for i in range(48):
-    old['module'][f'mixins.1.query_key_value.{i}.weight'] = \
-        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
-    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
-    old['module'][f'mixins.1.query_key_value.{i}.bias'] = \
-        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
-    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
-    old['module'][f'mixins.1.dense.{i}.weight'] = \
-        old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
-    del old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
-    old['module'][f'mixins.1.dense.{i}.bias'] = \
-        old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
-    del old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
-# %%
-missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
-
-# %%
-torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt')
-# # %%
-# import torch
-# old = torch.load("/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", map_location='cpu')
-
-# # %%
-# from collections import OrderedDict
-# new_ckpt = OrderedDict()
-# for k,v in old['state_dict'].items():
-#     new_ckpt[k] = v.detach()
-# torch.save(new_ckpt, 'pretrained/vqvae/l1+ms-ssim+revd_percep.pt')
-# # %%
-
-# %%
-
-old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
-del old['module']['word_embeddings.weight']
-#%%
-import torch
-
-from model.glm_model import GLMModel
-import argparse
-import os
-args = argparse.Namespace(
-    num_layers=48,
-    vocab_size=50304,
-    hidden_size=4096,
-    num_attention_heads=64,
-    max_sequence_length=1025,
-    hidden_dropout=0.1,
-    attention_dropout=0.1,
-    checkpoint_activations=True,
-    checkpoint_num_layers=1,
-    sandwich_ln=False,
-    model_parallel_size=1,
-    world_size=1,
-    rank=0
-    )
-
-init_method = 'tcp://'
-master_ip = os.getenv('MASTER_ADDR', 'localhost')
-master_port = os.getenv('MASTER_PORT', '6000')
-init_method += master_ip + ':' + master_port
-torch.distributed.init_process_group(
-    backend='nccl',
-    world_size=args.world_size, rank=args.rank,init_method=init_method)
-import mpu
- # Set the model-parallel / data-parallel communicators.
-mpu.initialize_model_parallel(args.model_parallel_size)
-print('bg')
-# %%
-model = GLMModel(args)
-# %%
-old['module']['mixins.block_position_embedding.block_position_embeddings.weight'] = old['module']['transformer.block_position_embeddings.weight']
-del old['module']['transformer.block_position_embeddings.weight']
-# %%
-missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=True)
-
-# %%
-import os
-os.makedirs('pretrained/glm/glm-en-10b/250000', exist_ok=True)
-torch.save(old, 'pretrained/glm/glm-en-10b/250000/mp_rank_00_model_states.pt')
-
-# %%
diff --git a/pretrained/vqvae/placeholder b/pretrained/vqvae/placeholder
deleted file mode 100644
index e69de29..0000000
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..84ec277
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,33 @@
+
+# Copyright (c) Ming Ding, et al. in KEG, Tsinghua University.
+#
+# LICENSE file in the root directory of this source tree.
+
+import json
+import sys
+import os
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+
+def _requirements():
+    return Path("requirements.txt").read_text()
+
+setup(
+    name="SwissArmyTransformer",
+    version="0.1",
+    description="A transformer-based framework with finetuning as the first class citizen.",
+    long_description=Path("README.md").read_text(),
+    long_description_content_type="text/markdown",
+    install_requires=_requirements(),
+    entry_points={},
+    packages=find_packages(),
+    url="https://github.com/THUDM/SwissArmyTransformer",
+    author="Ming Ding, et al.",
+    author_email="dm_thu@qq.com",
+    scripts=[],
+    include_package_data=True,
+    python_requires=">=3.5",
+    license="Apache 2.0 license"
+)
\ No newline at end of file
-- 
GitLab
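
Usage note: per change 2 in CHANGE_LOG.md, `filling_sequence` now returns the sampled
tokens together with `mems`, which is why the call in examples/cogview/inference_cogview.py
above gains a `[0]` index. A minimal sketch of the new calling convention (the keyword
arguments follow that call site; the leading positional arguments `model` and `seq` are
assumed names, since they fall outside the hunk shown):

```python
# Sketch: filling_sequence now returns (tokens, mems), so either unpack the pair
# or index [0] as inference_cogview.py does. `model` and `seq` are assumed names
# for the positional arguments not visible in the hunk above.
tokens, mems = filling_sequence(
    model, seq,
    batch_size=min(args.batch_size, mbz),
    strategy=strategy,
    log_attention_weights=log_attention_weights,
)
output_list.append(tokens)  # same effect as appending filling_sequence(...)[0]
```

Packaging note: with the new setup.py and MANIFEST.in, the repository can be built and
installed as a package (e.g. `pip install .`, or `python setup.py sdist` to produce a
distribution tarball), which is also why `dist/` and `*.egg-info` are added to .gitignore.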