Skip to content
Snippets Groups Projects
Commit 2d9f4796 authored by Hamid Shojanazeri's avatar Hamid Shojanazeri
Browse files

fixing the output format

parent 1e8ea70b
No related branches found
No related tags found
No related merge requests found
......@@ -42,13 +42,13 @@ def main(
prompt_file
), f"Provided Prompt file does not exist {prompt_file}"
with open(prompt_file, "r") as f:
user_prompt = "\n".join(f.readlines())
user_prompt = f.read()
elif not sys.stdin.isatty():
user_prompt = "\n".join(sys.stdin.readlines())
else:
print("No user prompt provided. Exiting.")
sys.exit(1)
# Set the seeds for reproducibility
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
......@@ -121,7 +121,7 @@ def main(
)
e2e_inference_time = (time.perf_counter()-start)*1000
print(f"the inference time is {e2e_inference_time} ms")
filling = tokenizer.decode(outputs[0], skip_special_tokens=True)
filling = tokenizer.batch_decode(outputs[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)[0]
# Safety check of the model output
safety_results = [check(filling) for check in safety_checker]
are_safe = all([r[1] for r in safety_results])
......
'''def remove_non_ascii(s: str) -> str:
def remove_non_ascii(s: str) -> str:
""" <FILL_ME>
return result
'''
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment