Commit 78b7ca5f authored by Ming Ding

pass cogview generate

parent cb3af696
Showing with 83 additions and 408 deletions
@@ -18,4 +18,6 @@ input*.txt
coco_scores/*
checkpoints/
*coco*
runs
+dist/
+*.egg-info
\ No newline at end of file
# 2021.10.29
1. Changed `mixins` from `ModuleList` to `ModuleDict`.
2. `filling_sequence` now returns both tokens and mems, and mems becomes a tensor.
3. Added `CachedAutoregressiveMixin`; a usage sketch follows this list.
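A minimal sketch of the new conventions (a hypothetical helper; the imports and argument names follow the inference scripts changed in this commit, and exact signatures may differ):
```python
from model.cached_autoregressive_model import CachedAutoregressiveMixin
from generation.autoregressive_sampling import filling_sequence

def generate_tokens(model, seq, batch_size, strategy):
    # caching is now provided by a named mixin on the base model
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    # filling_sequence now returns (tokens, mems), where mems is a tensor
    tokens, mems = filling_sequence(model, seq, batch_size=batch_size, strategy=strategy)
    return tokens
```
This mirrors the `)[0]` added to a sampling call in the diff below, which keeps only the returned tokens.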
## How to migrate an old SAT ckpt to the new version?
Example:
```python
import torch
old = torch.load('xxxxx/mp_rank_00_model_states.pt.old', map_location='cpu')
# rename keys: numeric mixin indices -> named mixin keys
oldm = old['module']
for k in list(oldm.keys()):
    if k.startswith('mixins.0'):
        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
    elif k.startswith('mixins.1'):
        new_k = k.replace('mixins.1', 'mixins.attention_plus')
    else:
        continue
    oldm[new_k] = oldm[k]
    del oldm[k]
# save to destination
torch.save(old, 'xxxxx/mp_rank_00_model_states.pt')
```
include requirements.txt
global-exclude __pycache__/*
graft SwissArmyTransformer/tokenization/embed_assets
\ No newline at end of file
File moved
__version__ = '0.1'
from .arguments import get_args
from .training import load_checkpoint, set_random_seed, initialize_distributed
from .tokenization import get_tokenizer
...
@@ -29,7 +29,7 @@ def _export_vocab_size_to_args(args, original_num_tokens):
        'tokens (new size: {})'.format(
        before, after - before, after))
    args.vocab_size = after
-    print_rank_0("prepare tokenizer done", flush=True)
+    print_rank_0("prepare tokenizer done")
    return tokenizer

def get_tokenizer(args=None, outer_tokenizer=None):
...
@@ -26,7 +26,7 @@ def new_module(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    module, cls = config.get("target").rsplit(".", 1)
-    model = getattr(importlib.import_module(module, package=None), cls)(**config.get("params", dict()))
+    model = getattr(importlib.import_module(module, package=__package__), cls)(**config.get("params", dict()))
    device = config.get("device", "cpu")
    model = model.to(device)
@@ -45,7 +45,7 @@ def new_module(config):
def load_decoder_default(device=0, path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"):
    # exp: load currently best decoder
-    target = "vqvae.vqvae_diffusion.Decoder"
+    target = ".vqvae_diffusion.Decoder"
    params = {
        "double_z": False,
        "z_channels": 256,
@@ -100,7 +100,7 @@ def load_model_default(device=0,
    }
    config = {
-        'target': "vqvae.vqvae_zc.VQVAE",
+        'target': ".vqvae_zc.VQVAE",
        'params': params,
        'ckpt': path,
        'device': device
@@ -116,7 +116,7 @@ def test_decode(configs, testcase, device=0, output_path=None):
        output_path = os.path.join("sample", f"{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg")
    quantize_config = {
-        "target": "vqvae.vqvae_zc.Quantize",
+        "target": ".vqvae_zc.Quantize",
        "params": {
            "dim": 256,
            "n_embed": 8192,
@@ -149,7 +149,7 @@ def test_decode_default(device=0):
    # testing 3 decoders: original/l1+ms-ssim/l1+ms-ssim+perceptual
    configs = [
        {
-            "target": "vqvae.vqvae_zc.Decoder",
+            "target": ".vqvae_zc.Decoder",
            "params": {
                "in_channel": 256,
                "out_channel": 3,
...
-from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer
+from .deepspeed_training import initialize_distributed, set_random_seed
from .model_io import load_checkpoint
\ No newline at end of file
@@ -14,7 +14,7 @@ import random
import torch
import numpy as np
-import SwissArmyTransformer.mpu
+from SwissArmyTransformer import mpu
from .utils import print_rank_0
...
@@ -62,7 +62,7 @@ def main(args):
                    batch_size=min(args.batch_size, mbz),
                    strategy=strategy,
                    log_attention_weights=log_attention_weights
-                    )
+                    )[0]
                )
        output_tokens = torch.cat(output_list, dim=0)
        # decoding
...
@@ -16,13 +16,13 @@ import argparse
from arguments import get_args
from model.base_model import BaseModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
from generation.autoregressive_sampling import get_masks_and_position_ids
from generation.utils import timed_name, save_multiple_images, generate_continually

def main(args):
    initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
    # build model
    model = BaseModel(args)
    if args.fp16:
...
#!/bin/bash
-CHECKPOINT_PATH=pretrained/cogview/cogview-base
+CHECKPOINT_PATH=/workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base
NLAYERS=48
NHIDDEN=2560
NATT=40
@@ -17,7 +17,7 @@ script_dir=$(dirname $script_path)
MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
       --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path /workspace/dm/SwissArmyTransformer/pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
       --mode inference \
       --distributed-backend nccl \
       --max-sequence-length 1089 \
...
@@ -19,7 +19,7 @@ from torchvision import transforms
from arguments import get_args
from model.cached_autoregressive_model import CachedAutoregressiveModel
from model.cuda2d_model import Cuda2dModel
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
from tokenization import get_tokenizer
from generation.sampling_strategies import BaseStrategy, IterativeEntfilterStrategy
from generation.autoregressive_sampling import filling_sequence
@@ -28,7 +28,7 @@ from generation.utils import timed_name, save_multiple_images, generate_continually
def main(args):
    initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
    # build model
    model = Cuda2dModel(args)
    if args.fp16:
...
@@ -22,7 +22,7 @@ from functools import partial
from arguments import get_args
from model.glm_model import GLMModel
from model.cached_autoregressive_model import CachedAutoregressiveMixin
-from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from training import load_checkpoint, initialize_distributed, set_random_seed
from generation.autoregressive_sampling import filling_sequence
from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
from generation.utils import timed_name, generate_continually
@@ -48,7 +48,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
def main(args):
    args.do_train = False
    initialize_distributed(args)
-    tokenizer = prepare_tokenizer(args)
+    tokenizer = get_tokenizer(args)
    # build model
    model = GLMModel(args)
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
...
# -*- encoding: utf-8 -*-
'''
@File : inference_cogview.py
@Time : 2021/10/09 19:41:58
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import random
import time
from datetime import datetime
import torch
import torch.nn.functional as F
import mpu
from arguments import get_args
from model.glm_model import GLMModel
from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
from generation.glm_sampling import filling_sequence_glm
from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
def read_context(tokenizer, args, output=None):
    terminate_runs, skip_run = 0, 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
            if args.block_lm and 'MASK]' not in raw_text:
                raw_text += ' ' + generation_mask
            if output is not None:
                output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            if args.block_lm:
                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
                if not raw_text.endswith('MASK]'):
                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
            context_length = len(context_tokens)
            if context_length >= args.max_sequence_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()
    if terminate_runs == 1:
        return terminate_runs, None, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])
    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    if mpu.get_model_parallel_rank() != 0:
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length


def get_batch(context_tokens, args):
    tokens = context_tokens
    tokens = tokens.view(1, -1).contiguous()
    tokens = tokens.to('cuda')
    # Get the masks and position ids.
    if args.block_lm:
        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
        if args.fp16:
            attention_mask = attention_mask.half()
        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
        if not args.no_block_position:
            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        position_ids = position_ids.unsqueeze(0)
    else:
        raise NotImplementedError
    return tokens, attention_mask, position_ids


def generate_samples(model, tokenizer, args):
    model.eval()
    output_path = "./samples"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            if args.block_lm:
                mems = []
                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
                mask_positions = []
                for token in mask_tokens:
                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
                mask_positions.sort()
                if args.no_block_position:
                    for mask_position in mask_positions:
                        position_ids[0, mask_position + 1:] += args.out_seq_length
                _, *mems = model(tokens, position_ids, attention_mask, *mems)
                for mask_position in mask_positions:
                    if args.no_block_position:
                        position = position_ids[0, mask_position].item()
                    else:
                        position = mask_position
                    if args.num_beams > 1:
                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
                                                      min_tgt_length=args.min_tgt_length)
                    else:
                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
                                                end_tokens=end_tokens)
                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
                                                            end_tokens=end_tokens)
                    tokens = torch.cat((tokens, new_tokens), dim=1)
            output_tokens_list = tokens.view(-1).contiguous()
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens
                print("\nGLM:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")
            torch.distributed.barrier(group=mpu.get_model_parallel_group())


def main(args):
    initialize_distributed(args)
    tokenizer = prepare_tokenizer(args)
    # build model
    model = GLMModel(args)
    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    load_checkpoint(model, args)
    set_random_seed(args.seed)
    model.eval()
    generate_samples(model, tokenizer, args)


if __name__ == "__main__":
    args = get_args()
    with torch.no_grad():
        main(args)
# %%
coco_30k = '/workspace/dm/SwissArmyTransformer/coco30k.txt'
with open(coco_30k, 'r') as fin:
    lines = fin.readlines()
import os
from posixpath import join
import shutil
prefix0 = '/workspace/dm/SwissArmyTransformer/coco_samples'
prefix1 = '/dataset/fd5061f6/mingding/SwissArmyTransformer/coco_samples'
cnt = 0
with open('coco_select.txt', 'w') as fout:
    for i, line in enumerate(lines):
        _id, text = line.strip().split('\t')
        if i % 200 == 0:
            print(i, cnt)
        src = os.path.join(prefix1, _id)
        if not os.path.exists(src):
            src = os.path.join(prefix0, _id)
        assert os.path.exists(src), _id
        fout.write(
            '\t'.join([text] + [
                os.path.join(src, f'{i}.jpg')
                for i in range(60)
            ]) + '\n'
        )
\ No newline at end of file
# %%
# import torch
# old = torch.load('pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt.sat1', map_location='cpu')
# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
# del old['module']['word_embeddings.weight']
# from model.base_model import BaseModel
# import argparse
# import os
# args = argparse.Namespace(
# num_layers=48,
# vocab_size=58240,
# hidden_size=2560,
# num_attention_heads=40,
# max_sequence_length=1089,
# hidden_dropout=0.1,
# attention_dropout=0.1,
# checkpoint_activations=True,
# checkpoint_num_layers=1,
# sandwich_ln=True,
# model_parallel_size=1,
# world_size=1,
# rank=0
# )
# init_method = 'tcp://'
# master_ip = os.getenv('MASTER_ADDR', 'localhost')
# master_port = os.getenv('MASTER_PORT', '6000')
# init_method += master_ip + ':' + master_port
# torch.distributed.init_process_group(
# backend='nccl',
# world_size=args.world_size, rank=args.rank,init_method=init_method)
# import mpu
# # Set the model-parallel / data-parallel communicators.
# mpu.initialize_model_parallel(args.model_parallel_size)
# print('bg')
# model = BaseModel(args)
# # %%
# missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
# torch.save(old, 'pretrained/cogview/cogview-caption/30000/mp_rank_00_model_states.pt')
# %%
import torch
old = torch.load('/dataset/fd5061f6/english_data/checkpoints/blocklm-10b-1024/126000/mp_rank_00_model_states.pt', map_location='cpu')
# old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
# del old['module']['word_embeddings.weight']
#%%
import torch
from model.cuda2d_model import Cuda2dModel
import argparse
import os
args = argparse.Namespace(
    num_layers=48,
    vocab_size=58240,
    hidden_size=2560,
    num_attention_heads=40,
    max_sequence_length=1089,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    checkpoint_activations=True,
    checkpoint_num_layers=1,
    sandwich_ln=True,
    model_parallel_size=1,
    world_size=1,
    rank=0,
    new_sequence_length=1089 + 4096,
    layout='0,64,1088,5184',
    kernel_size=9,
    kernel_size2=7
)
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
    backend='nccl',
    world_size=args.world_size, rank=args.rank, init_method=init_method)
import mpu
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
print('bg')
#%%
model = Cuda2dModel(args)
#%%
oldm = old['module']
for k in list(oldm.keys()):
    if k.startswith('mixins.0'):
        new_k = k.replace('mixins.0', 'mixins.extra_position_embedding')
    elif k.startswith('mixins.1'):
        new_k = k.replace('mixins.1', 'mixins.attention_plus')
    else:
        continue
    oldm[new_k] = oldm[k]
    del oldm[k]
#%%
old['module']['mixins.0.position_embeddings.weight'] = old['module']['transformer.position_embeddings_plus.weight']
del old['module']['transformer.position_embeddings_plus.weight']
for i in range(48):
    old['module'][f'mixins.1.query_key_value.{i}.weight'] = \
        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.weight']
    old['module'][f'mixins.1.query_key_value.{i}.bias'] = \
        old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
    del old['module'][f'transformer.layers.{i}.attention.query_key_value_plus.bias']
    old['module'][f'mixins.1.dense.{i}.weight'] = \
        old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
    del old['module'][f'transformer.layers.{i}.attention.dense_plus.weight']
    old['module'][f'mixins.1.dense.{i}.bias'] = \
        old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
    del old['module'][f'transformer.layers.{i}.attention.dense_plus.bias']
# %%
missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=False)
# %%
torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt')
# # %%
# import torch
# old = torch.load("/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", map_location='cpu')
# # %%
# from collections import OrderedDict
# new_ckpt = OrderedDict()
# for k,v in old['state_dict'].items():
# new_ckpt[k] = v.detach()
# torch.save(new_ckpt, 'pretrained/vqvae/l1+ms-ssim+revd_percep.pt')
# # %%
# %%
old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight']
del old['module']['word_embeddings.weight']
#%%
import torch
from model.glm_model import GLMModel
import argparse
import os
args = argparse.Namespace(
    num_layers=48,
    vocab_size=50304,
    hidden_size=4096,
    num_attention_heads=64,
    max_sequence_length=1025,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    checkpoint_activations=True,
    checkpoint_num_layers=1,
    sandwich_ln=False,
    model_parallel_size=1,
    world_size=1,
    rank=0
)
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
    backend='nccl',
    world_size=args.world_size, rank=args.rank, init_method=init_method)
import mpu
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
print('bg')
# %%
model = GLMModel(args)
# %%
old['module']['mixins.block_position_embedding.block_position_embeddings.weight'] = old['module']['transformer.block_position_embeddings.weight']
del old['module']['transformer.block_position_embeddings.weight']
# %%
missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=True)
# %%
import os
os.makedirs('pretrained/glm/glm-en-10b/250000', exist_ok=True)
torch.save(old, 'pretrained/glm/glm-en-10b/250000/mp_rank_00_model_states.pt')
# %%
setup.py 0 → 100644
# Copyright (c) Ming Ding, et al. in KEG, Tsinghua University.
#
# LICENSE file in the root directory of this source tree.

import json
import sys
import os
from pathlib import Path
from setuptools import find_packages, setup


def _requirements():
    return Path("requirements.txt").read_text()


setup(
    name="SwissArmyTransformer",
    version=0.1,
    description="A transformer-based framework with finetuning as the first class citizen.",
    long_description=Path("README.md").read_text(),
    long_description_content_type="text/markdown",
    install_requires=_requirements(),
    entry_points={},
    packages=find_packages(),
    url="https://github.com/THUDM/SwissArmyTransformer",
    author="Ming Ding, et al.",
    author_email="dm_thu@qq.com",
    scripts={},
    include_package_data=True,
    python_requires=">=3.5",
    license="Apache 2.0 license"
)
\ No newline at end of file