Skip to content
Snippets Groups Projects
Commit 8584bf91 authored by Ming Ding's avatar Ming Ding
Browse files

prepare for large-scale generation

parent a2625727
No related branches found
No related tags found
No related merge requests found
......@@ -14,4 +14,5 @@ pretrained/
*.jpg
*.jpeg
input*.txt
samples*
\ No newline at end of file
*samples*/
......@@ -246,6 +246,8 @@ def get_args(args_list=None):
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
if args.local_rank is not None:
args.device = args.local_rank
args.model_parallel_size = min(args.model_parallel_size, args.world_size)
if args.rank == 0:
......
......@@ -84,7 +84,6 @@ def filling_sequence_cuda2d(
for x in range(min(ll, step_cnt - warmup_steps)):
y = step_cnt - warmup_steps - x - 1
if y < rr:
print(x,y)
unfixed[..., -(layout[-1] - layout[-2]):].view(
batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
......
......@@ -42,7 +42,7 @@ def main(args):
def process(raw_text):
if args.with_id:
query_id, raw_text = raw_text.split()
query_id, raw_text = raw_text.split('\t')
print('raw text: ', raw_text)
text = query_template.format(raw_text)
seq = tokenizer.parse_query(text, img_size=args.img_size)
......
......@@ -52,7 +52,7 @@ def main(args):
def process(raw_text):
if args.with_id:
query_id, raw_text = raw_text.split()
query_id, raw_text = raw_text.split('\t')
print('raw text: ', raw_text)
text = query_template.format(raw_text)
seq = tokenizer.parse_query(text, img_size=args.img_size)
......@@ -62,7 +62,7 @@ def main(args):
txt_len = seq.index(tokenizer['[BASE]'])
log_attention_weights = torch.zeros(len(seq), len(seq),
device=args.device, dtype=torch.half if args.fp16 else torch.float32)
log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.6 # TODO args
# generation
seq = torch.cuda.LongTensor(seq, device=args.device)
......@@ -77,7 +77,9 @@ def main(args):
)
imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
output1 = filling_sequence_cuda2d(model, output0, blur64,
len_tim = output0.shape[0]
for tim2 in range(0, len_tim, 4):
output1 = filling_sequence_cuda2d(model, output0[tim2: tim2+4], blur64[tim2: tim2+4],
warmup_steps=3, block_hw=(4, 4),
strategy=strategy1
)
......
#!/bin/bash
NUM_WORKERS=4
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
# HOST_FILE_PATH="hostfile_single"
CHECKPOINT_PATH=pretrained/cogview/cogview2-base
NLAYERS=48
NHIDDEN=2560
NATT=40
MAXSEQLEN=1089
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
MPSIZE=1
#SAMPLING ARGS
TEMP=1.03
TOPK=200
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
gpt_options=" \
--tokenizer-type cogview \
--img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
--mode inference \
--distributed-backend nccl \
--max-sequence-length 1089 \
--sandwich-ln \
--fp16 \
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--temperature $TEMP \
--top_k $TOPK \
--sandwich-ln \
--input-source ./coco30k.txt \
--output-path coco_samples \
--batch-size 60 \
--max-inference-batch-size 12 \
--with-id \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} inference_cogview2.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment