diff --git a/README.md b/README.md
index 5dbedc3f05ecc40931baf29452f5f71dc193b8c3..370be8974ef3d3d9fe77c080844261c2c3b29115 100644
--- a/README.md
+++ b/README.md
@@ -107,13 +107,21 @@ torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_
 
 Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This can speed up the fine-tuning job. This feature is enabled through the `optimum` library from Hugging Face as a one-liner API, read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
+
 ### Fine-tuning using FSDP Only
 
 If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
 
 ```
diff --git a/configs/training.py b/configs/training.py
index 4c50372dafbabfbd4acc032476d8111e635bb7f0..7b0e82d44363e563a2420ef6a2e7977ad8c539c8 100644
--- a/configs/training.py
+++ b/configs/training.py
@@ -32,6 +32,7 @@ class train_config:
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool=False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
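The new `use_fast_kernels` flag added above is consumed by wrapping an already-loaded Hugging Face model with `BetterTransformer.transform`, guarded by an import check, which is the same pattern the script changes further down in this patch follow. The snippet below is only a minimal sketch of that pattern; the checkpoint path and dtype are placeholders, not values taken from this patch.

```python
# Minimal sketch (not part of the patch): opting a loaded Llama model into the
# SDPA-backed kernels via optimum's BetterTransformer one-liner.
import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-7b",          # placeholder checkpoint path
    torch_dtype=torch.bfloat16,
)

try:
    from optimum.bettertransformer import BetterTransformer
    model = BetterTransformer.transform(model)  # swaps attention for fused SDPA kernels
except ImportError:
    print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
```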
diff --git a/docs/inference.md b/docs/inference.md
index c08c91f10942f7d094e8c5c2bac6ca630b1f052d..47aa0f9822c87b233cba15cb6458b1092f9ad7c1 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -34,6 +34,18 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg
 
 ```
+
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This can speed up inference for batched inputs. This feature is enabled through the `optimum` library from Hugging Face as a one-liner API, read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg --use_fast_kernels
+
+python inference/inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels
+
+```
+
 ## Loading back FSDP checkpoints
 
 In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md
index 67c551eaa533665d47056d9cd80dcc57d81b60a4..5c8fdf353fb2b77cb8facaaeff2ba9bb1a8e511a 100644
--- a/docs/multi_gpu.md
+++ b/docs/multi_gpu.md
@@ -44,6 +44,13 @@ The args used in the command above are:
 
 We use `torchrun` here to spawn multiple processes for FSDP.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This can speed up the fine-tuning job. This feature is enabled through the `optimum` library from Hugging Face as a one-liner API, read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
 
 ### Fine-tuning using FSDP Only
 
@@ -51,7 +58,7 @@ If interested in running full parameter finetuning without making use of PEFT me
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --use_fast_kernels
 
 ```
diff --git a/inference/chat_completion.py b/inference/chat_completion.py
index bc5311d62ac735ced4043357ebae0144cf335884..d5c8378bdbbdc5f2db3fc33ef032dc47e23a07a4 100644
--- a/inference/chat_completion.py
+++ b/inference/chat_completion.py
@@ -34,6 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool=False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -59,6 +60,18 @@ def main(
     model = load_model(model_name, quantization)
     if peft_model:
         model = load_peft_model(model, peft_model)
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable the use of
+        Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This can speed up inference for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
diff --git a/inference/inference.py b/inference/inference.py
index 469bde4041a74af1ef94bd7cf6024b09c3e4a83f..c010c07ca784a96abb6c43d01ac8a79fe8505da2 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -32,6 +32,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool=False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -51,6 +52,23 @@ def main(
     torch.manual_seed(seed)
 
     model = load_model(model_name, quantization)
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable the use of
+        Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This can speed up inference for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
@@ -79,11 +97,6 @@ def main(
         print("Skipping the inferece as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
 
-    if peft_model:
-        model = load_peft_model(model, peft_model)
-
-    model.eval()
-
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
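Both inference scripts above now route attention through the same fused kernels. For readers unfamiliar with what BetterTransformer switches on, the illustrative snippet below (not part of this patch) shows the underlying PyTorch 2.x scaled dot product attention API and how the Flash / memory-efficient backends are selected; it assumes a CUDA device is available, and the tensor shapes are arbitrary toy values.

```python
# Illustration only (not from this patch): the "fast kernels" are PyTorch's
# fused scaled_dot_product_attention backends (Flash Attention / memory-efficient).
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) toy tensors; assumes a CUDA device.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Restrict dispatch to the Flash / memory-efficient backends; which one actually
# runs depends on the hardware, as the docs added in this patch describe.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```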
diff --git a/llama_finetuning.py b/llama_finetuning.py
index ccf8c68457d13597cfe96768b6f8808b94fe243b..b7c2dc60355c203fae73c808c8a595baeeb4d5c3 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -24,7 +24,6 @@ from transformers import (
     BitsAndBytesConfig
 )
 import torch.distributed as dist
-
 # Unused imports removed
 from utils.train_utils import (
     set_tokenizer_params,
@@ -95,7 +94,17 @@ def main(**kwargs):
         load_in_8bit=True if train_config.quantization else None,
         device_map="auto" if train_config.quantization else None,
     )
-
+    if train_config.enable_fsdp and train_config.use_fast_kernels:
+        """
+        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable the use of
+        Flash Attention or Xformer memory-efficient kernels
+        based on the hardware being used. This can speed up fine-tuning.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
diff --git a/requirements.txt b/requirements.txt
index 9258c3e4e75e7714034a4fd177cad7e4615dffef..6c499b1b9275547c40cb1d16310a5da0cbec1b1d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,4 @@ transformers>=4.31.0
 sentencepiece
 py7zr
 scipy
-
+optimum
diff --git a/scripts/spellcheck_conf/wordlist.txt b/scripts/spellcheck_conf/wordlist.txt
index 311e24215712cd76e3cb771be36e96172c3471a8..27c7323cdbcfb06bfca606234c4d2ba31c577b9d 100644
--- a/scripts/spellcheck_conf/wordlist.txt
+++ b/scripts/spellcheck_conf/wordlist.txt
@@ -1090,5 +1090,34 @@ intra
 nightlies
 recenly
 uncomment
+BFloat
+DDP
+LLM
+Xformer
+accuracies
+activations
+anyprecision
+aplaca
+assembels
+boolean
+checkpoining
+defatults
+gradinets
+itermediate
+recommond
+scaler
+sharding
+slurm
+summarization
+theJfleg
+xA
+Jupyter
+LLM
+Xformer
+dataset's
+jupyter
+mutli
+summarization
+xA
 Sanitization
-tokenization
\ No newline at end of file
+tokenization
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 03087de7b27d9b3bed1bf2d7d424175c1e503c2b..8113ef173bb748a083e88a899ac8033f131e5421 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -5,6 +5,7 @@
 import os
 import sys
 from typing import List
 import yaml
+import time
 import fire
 import torch
@@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     train_loss = []
     val_prep = []
     val_loss =[]
+    epoch_times = []
+    checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
     for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
@@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
                else:
                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-
+        epoch_end_time = time.perf_counter() - epoch_start_time
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -117,6 +122,7 @@
 
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -136,6 +142,7 @@
 
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
+            checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
                     dist.barrier()
@@ -176,7 +183,8 @@
                     print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
-
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
@@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
-
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+    avg_epoch_time = sum(epoch_times)/len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/len(checkpoint_times)
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
@@ -204,7 +213,9 @@
     if train_config.run_validation:
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
-
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
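The `utils/train_utils.py` changes above add simple wall-clock bookkeeping around each epoch and each checkpoint save, then report the averages in `results`. The standalone sketch below reproduces that pattern with placeholder work so the reporting logic can be seen in isolation; `run_epoch` and `save_checkpoint` are hypothetical stand-ins, not functions from this repo.

```python
import time

def run_epoch() -> None:
    """Hypothetical stand-in for one training epoch."""
    time.sleep(0.1)

def save_checkpoint() -> None:
    """Hypothetical stand-in for writing a checkpoint."""
    time.sleep(0.05)

epoch_times, checkpoint_times = [], []
for _ in range(3):
    start = time.perf_counter()
    run_epoch()
    epoch_times.append(time.perf_counter() - start)      # mirrors epoch_end_time above

    start = time.perf_counter()
    save_checkpoint()
    checkpoint_times.append(time.perf_counter() - start)  # mirrors checkpoint_end_time above

results = {
    "avg_epoch_time": sum(epoch_times) / len(epoch_times),
    "avg_checkpoint_time": sum(checkpoint_times) / len(checkpoint_times),
}
print(results)
```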