diff --git a/docs/inference.md b/docs/inference.md
index 47aa0f9822c87b233cba15cb6458b1092f9ad7c1..67ee3dca697a4a17d0d12404cb3b885fe52f91ae 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -27,6 +27,21 @@
 inference/samsum_prompt.txt
 ...
 ```
+**Note**
+Currently, the pad token in the [HuggingFace Llama tokenizer defaults to `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). We add the padding token as a special token to the tokenizer, which in this case also requires resizing the token embeddings as shown below:
+
+```python
+tokenizer.add_special_tokens(
+    {
+
+        "pad_token": "<PAD>",
+    }
+)
+model.resize_token_embeddings(model.config.vocab_size + 1)
+```
+Padding is only needed for batch inference. In this [example](../inference/inference.py) the batch size is 1, so padding is not strictly required; we include the code pointer as a reference for the batch inference case.
+
+
 **Chat completion**
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 
diff --git a/inference/inference.py b/inference/inference.py
index c010c07ca784a96abb6c43d01ac8a79fe8505da2..81668e3fb5b748a37e0f657c5fa43c8fc03541a8 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -31,7 +31,8 @@ def main(
     length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    max_padding_length: int=None, # the max padding length to be used with the tokenizer when padding the prompts.
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -76,10 +77,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    model.resize_token_embeddings(model.config.vocab_size + 1)
 
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
+                                        enable_salesforce_content_safety,
                                         )
 
     # Safety check of the user prompt
@@ -94,10 +96,15 @@ def main(
         if not is_safe:
            print(method)
            print(report)
-            print("Skipping the inferece as the prompt is not safe.")
+            print("Skipping the inference as the prompt is not safe.")
            sys.exit(1)  # Exit the program with an error status
+
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+    batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
-    batch = tokenizer(user_prompt, return_tensors="pt")
 
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
diff --git a/inference/safety_utils.py b/inference/safety_utils.py
index 9c6d0c36163115b78a8d15c5eef85ab03151ce93..bc321eb929df52b2c87e1ed0b52f7061d468094c 100644
--- a/inference/safety_utils.py
+++ b/inference/safety_utils.py
@@ -154,14 +154,14 @@ class AzureSaftyChecker(object):
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
-                       enable_saleforce_content_safety,
+                       enable_salesforce_content_safety,
                        ):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
     if enable_sensitive_topics:
         safety_checker.append(AuditNLGSensitiveTopics())
-    if enable_saleforce_content_safety:
+    if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
 
     return safety_checker
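
The note added to `docs/inference.md` above points out that padding only matters once several prompts are batched together. The sketch below illustrates what a padded, batched call could look like; it is not part of the change. The prompt list, the `max_length` value of 64 (standing in for `max_padding_length`), and the assumption that `model` and `tokenizer` are already loaded as in `inference/inference.py` are all illustrative.

```python
# Hypothetical sketch of batched generation with padding; assumes `model` and
# `tokenizer` are already loaded as in inference/inference.py.
import torch

prompts = [
    "Summarize the following dialogue: ...",  # example prompts (assumed)
    "Summarize the following dialogue: ...",
]

# Register the pad token and grow the embedding table by one row, as in the note above.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(model.config.vocab_size + 1)

# Pad/truncate every prompt to the same length so the batch stacks into one tensor.
batch = tokenizer(
    prompts,
    padding="max_length",
    truncation=True,
    max_length=64,            # plays the role of max_padding_length
    return_tensors="pt",
)
batch = {k: v.to("cuda") for k, v in batch.items()}

model.eval()
with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=100)

for text in tokenizer.batch_decode(outputs, skip_special_tokens=True):
    print(text)
```

One design note: for decoder-only models such as Llama, batched generation generally behaves better with left padding (`tokenizer.padding_side = "left"`), whereas the sketch above keeps the tokenizer's default padding side.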