Skip to content
Snippets Groups Projects
Commit 754b5d22 authored by Hamid Shojanazeri's avatar Hamid Shojanazeri
Browse files

clean up and typo fixes

parent 9b0eae40
Branches
Tags
No related merge requests found
......@@ -31,7 +31,7 @@ def main(
length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
**kwargs
):
......@@ -59,10 +59,11 @@ def main(
"pad_token": "<PAD>",
}
)
model.resize_token_embeddings(model.config.vocab_size + 1)
safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_saleforce_content_safety,
enable_salesforce_content_safety,
)
# Safety check of the user prompt
......@@ -77,7 +78,7 @@ def main(
if not is_safe:
print(method)
print(report)
print("Skipping the inferece as the prompt is not safe.")
print("Skipping the inference as the prompt is not safe.")
sys.exit(1) # Exit the program with an error status
if peft_model:
......@@ -85,7 +86,6 @@ def main(
model.eval()
batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
model.resize_token_embeddings(model.config.vocab_size + 1)
batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter()
with torch.no_grad():
......
......@@ -154,14 +154,14 @@ class AzureSaftyChecker(object):
# Function to determine which safety checker to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_saleforce_content_safety,
enable_salesforce_content_safety,
):
safety_checker = []
if enable_azure_content_safety:
safety_checker.append(AzureSaftyChecker())
if enable_sensitive_topics:
safety_checker.append(AuditNLGSensitiveTopics())
if enable_saleforce_content_safety:
if enable_salesforce_content_safety:
safety_checker.append(SalesforceSafetyChecker())
return safety_checker
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment