diff --git a/docs/inference.md b/docs/inference.md
index 67ee3dca697a4a17d0d12404cb3b885fe52f91ae..2f15c9dbd0814dc6b30128580dc79d6d03f06c7b 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -41,6 +41,29 @@ model.resize_token_embeddings(model.config.vocab_size + 1)
 ```
 Padding would be required for batch inference. In this [example](../inference/inference.py), batch size = 1, so essentially padding is not required. However, we added the code pointer as an example in case of batch inference.
 
+**Code Llama**
+
+Code Llama was recently released in three flavors: a base model that supports multiple programming languages, a Python fine-tuned model, and an instruction fine-tuned and aligned variation of Code Llama. Please read more [here](https://ai.meta.com/blog/code-llama-large-language-model-coding/).
+
+Find the scripts to run Code Llama [here](../inference/code-llama/), where there are two examples: code completion and code infilling.
+
+**Note** Please find the right model on the Hugging Face Hub [here](https://huggingface.co/codellama).
+
+To run the code completion example:
+
+```bash
+
+python code_completion_example.py --model_name MODEL_NAME --prompt_file code_completion_prompt.txt --temperature 0.2 --top_p 0.9
+
+```
+
+To run the code infilling example:
+
+```bash
+
+python code_infilling_example.py --model_name MODEL_NAME --prompt_file code_infilling_prompt.txt --temperature 0.2 --top_p 0.9
+
+```
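+
+Both scripts also fall back to reading the prompt from standard input when no `--prompt_file` is passed, so a prompt can be piped in as well, for example:
+
+```bash
+cat code_completion_prompt.txt | python code_completion_example.py --model_name MODEL_NAME --temperature 0.2 --top_p 0.9
+```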
 **Chat completion**
 
 The inference folder also includes a chat completion example that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
diff --git a/inference/code-llama/code_completion_example.py b/inference/code-llama/code_completion_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..106c54ad655077da79df9295ed261c75161856d2
--- /dev/null
+++ b/inference/code-llama/code_completion_example.py
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import fire
+import torch
+import os
+import sys
+import time
+from typing import List
+
+from transformers import CodeLlamaTokenizer
+sys.path.append("..")
+from safety_utils import get_safety_checker
+from model_utils import load_model, load_peft_model, load_llama_from_config
+
+def main(
+    model_name,
+    peft_model: str=None,
+    quantization: bool=False,
+    max_new_tokens =100, # The maximum number of tokens to generate
+    prompt_file: str=None,
+    seed: int=42, # seed value for reproducibility
+    do_sample: bool=True, # Whether or not to use sampling; use greedy decoding otherwise.
+    min_length: int=None, # The minimum length of the sequence to be generated: input prompt + min_new_tokens
+    use_cache: bool=True, # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
+    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float=1.0, # The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int=1, # [optional] Exponential penalty to the length that is used with beam-based generation.
+    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety API
+    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety Flan T5
+    max_padding_length: int=None, # the max padding length used when the tokenizer pads the prompts
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformers memory-efficient kernels
+    **kwargs
+):
+    if prompt_file is not None:
+        assert os.path.exists(
+            prompt_file
+        ), f"Provided prompt file does not exist: {prompt_file}"
+        with open(prompt_file, "r") as f:
+            user_prompt = "\n".join(f.readlines())
+    elif not sys.stdin.isatty():
+        user_prompt = "\n".join(sys.stdin.readlines())
+    else:
+        print("No user prompt provided. Exiting.")
+        sys.exit(1)
+
+    # Set the seeds for reproducibility
+    torch.cuda.manual_seed(seed)
+    torch.manual_seed(seed)
+
+    model = load_model(model_name, quantization)
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' enables the use of Flash Attention or Xformers memory-efficient kernels
+        based on the hardware being used. This speeds up inference when used for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
+    tokenizer = CodeLlamaTokenizer.from_pretrained(model_name)
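+    # Register a dedicated <PAD> token and resize the embedding matrix to match, mirroring the
+    # padding note in docs/inference.md; with batch size 1 the padding itself is optional, but
+    # this keeps the script ready for batched prompts.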
print("Model output deemed unsafe.") + for method, is_safe, report in safety_results: + if not is_safe: + print(method) + print(report) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/inference/code-llama/code_completion_prompt.txt b/inference/code-llama/code_completion_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..8e184e2fe3fd374d17de604b16bf9d3b1489a10e --- /dev/null +++ b/inference/code-llama/code_completion_prompt.txt @@ -0,0 +1,7 @@ +import argparse + +def main(string: str): + print(string) + print(string[::-1]) + +if __name__ == "__main__": \ No newline at end of file diff --git a/inference/code-llama/code_infilling_example.py b/inference/code-llama/code_infilling_example.py new file mode 100644 index 0000000000000000000000000000000000000000..57a367d27627ff9817f0984280c6e071c7f2313c --- /dev/null +++ b/inference/code-llama/code_infilling_example.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# from accelerate import init_empty_weights, load_checkpoint_and_dispatch + +import fire +import torch +import os +import sys +import time +from typing import List + +from transformers import CodeLlamaTokenizer +sys.path.append("..") +from safety_utils import get_safety_checker +from model_utils import load_model, load_peft_model, load_llama_from_config + +def main( + model_name, + peft_model: str=None, + quantization: bool=False, + max_new_tokens =100, #The maximum numbers of tokens to generate + prompt_file: str=None, + seed: int=42, #seed value for reproducibility + do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise. + min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens + use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: float=1.0, # [optional] The value used to modulate the next token probabilities. + top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering. + repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty. + length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. + enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api + enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs + enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 + max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts. 
+
+    # Set the seeds for reproducibility
+    torch.cuda.manual_seed(seed)
+    torch.manual_seed(seed)
+
+    model = load_model(model_name, quantization)
+    model.config.tp_size = 1
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' enables the use of Flash Attention or Xformers memory-efficient kernels
+        based on the hardware being used. This speeds up inference when used for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
+    tokenizer = CodeLlamaTokenizer.from_pretrained(model_name)
+    tokenizer.add_special_tokens(
+        {
+            "pad_token": "<PAD>",
+        }
+    )
+    model.resize_token_embeddings(model.config.vocab_size + 1)
+
+    safety_checker = get_safety_checker(enable_azure_content_safety,
+                                        enable_sensitive_topics,
+                                        enable_salesforce_content_safety,
+                                        )
+
+    # Safety check of the user prompt
+    safety_results = [check(user_prompt) for check in safety_checker]
+    are_safe = all([r[1] for r in safety_results])
+    if are_safe:
+        print("User prompt deemed safe.")
+        print(f"User prompt:\n{user_prompt}")
+    else:
+        print("User prompt deemed unsafe.")
+        for method, is_safe, report in safety_results:
+            if not is_safe:
+                print(method)
+                print(report)
+        print("Skipping the inference as the prompt is not safe.")
+        sys.exit(1)  # Exit the program with an error status
+
+    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
+    batch = {k: v.to("cuda") for k, v in batch.items()}
+
+    start = time.perf_counter()
+    with torch.no_grad():
+        outputs = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            min_length=min_length,
+            use_cache=use_cache,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            **kwargs
+        )
+    e2e_inference_time = (time.perf_counter()-start)*1000
+    print(f"The inference time is {e2e_inference_time} ms")
+    filling = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Safety check of the model output
+    safety_results = [check(filling) for check in safety_checker]
+    are_safe = all([r[1] for r in safety_results])
+    if are_safe:
+        print("User input and model output deemed safe.")
+        print(user_prompt.replace("<FILL_ME>", filling))
+    else:
+        print("Model output deemed unsafe.")
+        for method, is_safe, report in safety_results:
+            if not is_safe:
+                print(method)
+                print(report)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/inference/code-llama/code_infilling_prompt.txt b/inference/code-llama/code_infilling_prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1ad76cacacc1406439a95b9684275c80433a72ce
--- /dev/null
+++ b/inference/code-llama/code_infilling_prompt.txt
@@ -0,0 +1,4 @@
+'''def remove_non_ascii(s: str) -> str:
+    """ <FILL_ME>
+    return result
+'''
\ No newline at end of file