diff --git a/README.md b/README.md index ab122be269381e696b1895c1ae31ab0684320f75..72f1312bba4125cec9d83a0bcf05f8157aab80ef 100644 --- a/README.md +++ b/README.md @@ -107,13 +107,21 @@ torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_ Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training. +## Flash Attention and Xformer Memory Efficient Kernels + +Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/). + +```bash +torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels +``` + ### Fine-tuning using FSDP Only If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs. ```bash -torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels ``` diff --git a/configs/training.py b/configs/training.py index 4c50372dafbabfbd4acc032476d8111e635bb7f0..7b0e82d44363e563a2420ef6a2e7977ad8c539c8 100644 --- a/configs/training.py +++ b/configs/training.py @@ -32,6 +32,7 @@ class train_config: dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP save_optimizer: bool=False # will be used if using FSDP + use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels diff --git a/docs/inference.md b/docs/inference.md index c08c91f10942f7d094e8c5c2bac6ca630b1f052d..47aa0f9822c87b233cba15cb6458b1092f9ad7c1 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -34,6 +34,18 @@ The inference folder also includes a chat completion example, that adds built-in python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg ``` + +## Flash Attention and Xformer Memory Efficient Kernels + +Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up inference when used for batched inputs. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/). + +```bash +python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg --use_fast_kernels + +python inference/inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels + +``` + ## Loading back FSDP checkpoints In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above. diff --git a/docs/mutli_gpu.md b/docs/mutli_gpu.md index a4396deeade7438bf026f8ac41a2348b28b9438c..3fed23f557c46160b0827468f190cb34d2f9984c 100644 --- a/docs/mutli_gpu.md +++ b/docs/mutli_gpu.md @@ -44,6 +44,13 @@ The args used in the command above are: We use `torchrun` here to spawn multiple processes for FSDP. +## Flash Attention and Xformer Memory Efficient Kernels + +Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/). + +```bash +torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels +``` ### Fine-tuning using FSDP Only @@ -51,7 +58,7 @@ If interested in running full parameter finetuning without making use of PEFT me ```bash -torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 +torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --use_fast_kernels ``` diff --git a/inference/chat_completion.py b/inference/chat_completion.py index bc5311d62ac735ced4043357ebae0144cf335884..ea810816b7992947d3c922bd67428703f593b459 100644 --- a/inference/chat_completion.py +++ b/inference/chat_completion.py @@ -11,6 +11,7 @@ from typing import List from peft import PeftModel, PeftConfig from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM +from optimum.bettertransformer import BetterTransformer from safety_utils import get_safety_checker from model_utils import load_model, load_peft_model from chat_utils import read_dialogs_from_file, format_tokens @@ -34,6 +35,7 @@ def main( enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 + use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs ): if prompt_file is not None: @@ -57,6 +59,18 @@ def main( torch.cuda.manual_seed(seed) torch.manual_seed(seed) model = load_model(model_name, quantization) + if use_fast_kernels: + """ + Setting 'use_fast_kernels' will enable + using of Flash Attention or Xformer memory-efficient kernels + based on the hardware being used. This would speed up inference when used for batched inputs. + """ + try: + from optimum.bettertransformer import BetterTransformer + except ImportError: + print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") + + model = BetterTransformer.transform(model) if peft_model: model = load_peft_model(model, peft_model) tokenizer = LlamaTokenizer.from_pretrained(model_name) diff --git a/inference/inference.py b/inference/inference.py index 469bde4041a74af1ef94bd7cf6024b09c3e4a83f..7c4fb2007cd65f2cd7323aa139c037bd31e38114 100644 --- a/inference/inference.py +++ b/inference/inference.py @@ -32,6 +32,7 @@ def main( enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 + use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs ): if prompt_file is not None: @@ -51,6 +52,18 @@ def main( torch.manual_seed(seed) model = load_model(model_name, quantization) + if use_fast_kernels: + """ + Setting 'use_fast_kernels' will enable + using of Flash Attention or Xformer memory-efficient kernels + based on the hardware being used. This would speed up inference when used for batched inputs. + """ + try: + from optimum.bettertransformer import BetterTransformer + except ImportError: + print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") + + model = BetterTransformer.transform(model) tokenizer = LlamaTokenizer.from_pretrained(model_name) tokenizer.add_special_tokens( { diff --git a/llama_finetuning.py b/llama_finetuning.py index ccf8c68457d13597cfe96768b6f8808b94fe243b..2b1d63fbff7b155a3e757b654e281b271f38ad0e 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -24,6 +24,7 @@ from transformers import ( BitsAndBytesConfig ) import torch.distributed as dist +from optimum.bettertransformer import BetterTransformer # Unused imports removed from utils.train_utils import ( @@ -95,7 +96,17 @@ def main(**kwargs): load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, ) - + if train_config.enable_fsdp and train_config.use_fast_kernels: + """ + For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable + using of Flash Attention or Xformer memory-efficient kernels + based on the hardware being used. This would speed up fine-tuning. + """ + try: + from optimum.bettertransformer import BetterTransformer + except ImportError: + print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") + model = BetterTransformer.transform(model) print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) # Prepare the model for int8 training if quantization is enabled diff --git a/requirements.txt b/requirements.txt index 9258c3e4e75e7714034a4fd177cad7e4615dffef..6c499b1b9275547c40cb1d16310a5da0cbec1b1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ transformers>=4.31.0 sentencepiece py7zr scipy - +optimum diff --git a/utils/train_utils.py b/utils/train_utils.py index 3fa4c0cf1d34a3b1e6bbad010597216d7669d1e6..c051afd9e96dd59e6282e03d4b2283eea0277d85 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -5,6 +5,7 @@ import os import sys from typing import List import yaml +import time import fire import torch @@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_loss = [] val_prep = [] val_loss =[] + epoch_times = [] + checkpoint_times = [] results = {} best_val_loss = float("inf") for epoch in range(train_config.num_epochs): + epoch_start_time = time.perf_counter() with MemoryTrace() as memtrace: # track the memory usage model.train() total_loss = 0.0 @@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche print(f"\n step {step} is completed and loss is {loss.detach().float()}") else: print(f"\n step {step} is completed and loss is {loss.detach().float()}") - + epoch_end_time = time.perf_counter()-epoch_start_time + epoch_times.append(epoch_end_time) # Reducing total_loss across all devices if there's more than one CUDA device if torch.cuda.device_count() > 1 and train_config.enable_fsdp: dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) @@ -117,6 +122,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_prep.append(train_perplexity) train_loss.append(train_epoch_loss) + if train_config.enable_fsdp: if rank==0: print(f"Max CUDA memory allocated was {memtrace.peak} GB") @@ -136,6 +142,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.run_validation: eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer) + checkpoint_start_time = time.perf_counter() if train_config.save_model and eval_epoch_loss < best_val_loss: if train_config.enable_fsdp: dist.barrier() @@ -176,7 +183,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche print("=====================================================") if train_config.enable_fsdp: dist.barrier() - + checkpoint_end_time = time.perf_counter() - checkpoint_start_time + checkpoint_times.append(checkpoint_end_time) if eval_epoch_loss < best_val_loss: best_val_loss = eval_epoch_loss if train_config.enable_fsdp: @@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.enable_fsdp: if rank==0: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}") + print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s") else: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}") - + print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s") + avg_epoch_time = sum(epoch_times)/ len(epoch_times) + avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) avg_train_prep = sum(train_prep)/len(train_prep) avg_train_loss = sum(train_loss)/len(train_loss) if train_config.run_validation: @@ -204,7 +213,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.run_validation: results['avg_eval_prep'] = avg_eval_prep results['avg_eval_loss'] = avg_eval_loss - + results["avg_epoch_time"] = avg_epoch_time + results["avg_checkpoint_time"] = avg_checkpoint_time + #saving the training params including fsdp setting for reference. if train_config.enable_fsdp and not train_config.use_peft: save_train_params(train_config, fsdp_config, rank)