Commit c9891200 authored by Zhengxiao Du

Use beam search strategy for inference_glm.py

parent 41a21220
@@ -61,8 +61,12 @@ def main(args):
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query
-    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
+    if args.num_beams > 1:
+        strategy = BeamSearchStrategy(num_beams=args.num_beams, length_penalty=args.length_penalty,
+                                      no_repeat_ngram_size=args.no_repeat_ngram_size)
+    else:
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
     def process(raw_text):
         if args.with_id:
             query_id, raw_text = raw_text.split('\t')
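
The hunk above is the substance of the commit: the decoding strategy now depends on the requested beam count. A minimal standalone sketch of the same selection logic, assuming the BaseStrategy and BeamSearchStrategy classes used in this file (the import path and the build_strategy helper name are assumptions for illustration, not part of this diff):

# Sketch only; import path and helper name are assumed, not shown in this diff.
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy

def build_strategy(args, end_tokens):
    # More than one beam: beam search with a length penalty and
    # n-gram repetition blocking.
    if args.num_beams > 1:
        return BeamSearchStrategy(num_beams=args.num_beams,
                                  length_penalty=args.length_penalty,
                                  no_repeat_ngram_size=args.no_repeat_ngram_size)
    # Otherwise: temperature / top-k sampling, as before this commit.
    return BaseStrategy(temperature=args.temperature, top_k=args.top_k,
                        end_tokens=end_tokens)
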
@@ -77,7 +81,7 @@ def main(args):
         print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
         # find mask tokens positions
         # mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
         # mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
@@ -93,7 +97,7 @@ def main(args):
         output_list = [seq]
         # continually detect the first mark position
         while True:
-            seq = output_list[0] # TODO find the best one
+            seq = output_list[0]  # TODO find the best one
             # detect
             mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
             mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
@@ -105,7 +109,7 @@
                 pass
             if mask_position == len(seq):
                 break
             get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
             output_list = []
             for tim in range(max(args.batch_size // mbz, 1)):
@@ -113,12 +117,12 @@
                     seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
                     device=args.device)
                 output, _mems = filling_sequence(model, input_seq,
-                    batch_size=min(args.batch_size, mbz),
-                    strategy=strategy,
-                    log_attention_weights=None,
-                    get_masks_and_position_ids=get_func
-                    ) # we don't use mems, fill back
-                if isinstance(output, torch.Tensor): # different strategies
+                                                 batch_size=min(args.batch_size, mbz),
+                                                 strategy=strategy,
+                                                 log_attention_weights=None,
+                                                 get_masks_and_position_ids=get_func
+                                                 )  # we don't use mems, fill back
+                if isinstance(output, torch.Tensor):  # different strategies
                     output = list(output)
                 output_list.extend(output)
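
The unchanged tail of this hunk explains the "# different strategies" comment: sampling typically hands back a padded batch tensor, while beam search may hand back a plain Python list of finished sequences of differing lengths, so the result is normalized to a list before being appended to output_list. A small sketch of that normalization under that assumption (the exact return types are not shown in this diff):

import torch

def to_sequence_list(output):
    # Hypothetical helper: split a padded batch tensor into per-sample rows;
    # pass a list of variable-length beam results through unchanged.
    if isinstance(output, torch.Tensor):
        output = list(output)
    return output
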
@@ -147,7 +151,7 @@ def main(args):
         else:
             prefix = raw_text.replace('/', '')[:20]
         full_path = timed_name(prefix, '.txt', args.output_path)
-        print(txts[0]) # print the first.
+        print(txts[0])  # print the first.
         with open(full_path, 'w') as fout:
             for txt in txts:
                 fout.write(txt + '\n')
@@ -163,6 +167,6 @@ if __name__ == "__main__":
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
     with torch.no_grad():
         main(args)
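
The new branch reads args.num_beams, args.length_penalty and args.no_repeat_ngram_size, but the captured hunks do not show where those flags are registered on py_parser. A hedged sketch of what such definitions would presumably look like (flag names mirror the attributes used above; types and defaults are assumptions):

# Assumed flag definitions; not part of the captured diff.
py_parser.add_argument('--num_beams', type=int, default=1)
py_parser.add_argument('--length_penalty', type=float, default=1.0)
py_parser.add_argument('--no_repeat_ngram_size', type=int, default=0)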