Skip to content
Snippets Groups Projects
Commit 564ef2f6 authored by Hamid Shojanazeri
Browse files

remove padding logic

parent 277a292f
No related branches found
No related tags found
No related merge requests found
...@@ -33,7 +33,6 @@ def main( ...@@ -33,7 +33,6 @@ def main(
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs **kwargs
): ):
...@@ -70,14 +69,6 @@ def main( ...@@ -70,14 +69,6 @@ def main(
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(
{
"pad_token": "<PAD>",
}
)
model.resize_token_embeddings(model.config.vocab_size + 1)
safety_checker = get_safety_checker(enable_azure_content_safety, safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics, enable_sensitive_topics,
enable_salesforce_content_safety, enable_salesforce_content_safety,
...@@ -98,7 +89,7 @@ def main( ...@@ -98,7 +89,7 @@ def main(
print("Skipping the inference as the prompt is not safe.") print("Skipping the inference as the prompt is not safe.")
sys.exit(1) # Exit the program with an error status sys.exit(1) # Exit the program with an error status
batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt") batch = tokenizer(user_prompt, return_tensors="pt")
batch = {k: v.to("cuda") for k, v in batch.items()} batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter() start = time.perf_counter()
......
...@@ -33,7 +33,6 @@ def main( ...@@ -33,7 +33,6 @@ def main(
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs **kwargs
): ):
...@@ -70,13 +69,6 @@ def main( ...@@ -70,13 +69,6 @@ def main(
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(
{
"pad_token": "<PAD>",
}
)
model.resize_token_embeddings(model.config.vocab_size + 1)
safety_checker = get_safety_checker(enable_azure_content_safety, safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics, enable_sensitive_topics,
...@@ -98,7 +90,7 @@ def main( ...@@ -98,7 +90,7 @@ def main(
print("Skipping the inference as the prompt is not safe.") print("Skipping the inference as the prompt is not safe.")
sys.exit(1) # Exit the program with an error status sys.exit(1) # Exit the program with an error status
batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt") batch = tokenizer(user_prompt, return_tensors="pt")
batch = {k: v.to("cuda") for k, v in batch.items()} batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter() start = time.perf_counter()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment