diff --git a/pred.py b/pred.py index bf5a27ea9f0552520d0972c5e5d07ea7bc7dce62..ae1ec6fae31a86d5576f09493cd0c05f482403bd 100644 --- a/pred.py +++ b/pred.py @@ -11,7 +11,7 @@ def build_chat(tokenizer, prompt): def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device): preds = [] - for json_obj in tqdm(data[:10]): + for json_obj in tqdm(data): prompt = prompt_format.format(**json_obj) # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]