From 26e877fd4278ee38be8d60a2d8daff8c6af73433 Mon Sep 17 00:00:00 2001 From: Kai Wu <kaiwu@meta.com> Date: Mon, 29 Apr 2024 16:36:24 -0700 Subject: [PATCH] changed readme, unified the context interface and added get_flops_per_sec() --- docs/multi_gpu.md | 70 ++++++++++++++----------- docs/single_gpu.md | 71 ++++++++++++++----------- recipes/finetuning/README.md | 72 +++++++++++++++----------- src/llama_recipes/configs/training.py | 4 +- src/llama_recipes/utils/flop_utils.py | 62 ++++++++++++---------- src/llama_recipes/utils/train_utils.py | 16 +++--- 6 files changed, 165 insertions(+), 130 deletions(-) diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md index fd1bf4cd..a6c44cbf 100644 --- a/docs/multi_gpu.md +++ b/docs/multi_gpu.md @@ -115,35 +115,47 @@ torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --m It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings: ```python - -model_name: str="PATH/to/LLAMA 2/7B" -enable_fsdp: bool= False -run_validation: bool=True -batch_size_training: int=4 -gradient_accumulation_steps: int=1 -num_epochs: int=3 -num_workers_dataloader: int=2 -lr: float=2e-4 -weight_decay: float=0.0 -gamma: float= 0.85 -use_fp16: bool=False -mixed_precision: bool=True -val_batch_size: int=4 -dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset -peft_method: str = "lora" # None , llama_adapter, prefix -use_peft: bool=False -output_dir: str = "./ft-output" -freeze_layers: bool = False -num_freeze_layers: int = 1 -quantization: bool = False -save_model: bool = False -dist_checkpoint_root_folder: str="model_checkpoints" -dist_checkpoint_folder: str="fine-tuned" -save_optimizer: bool=False -flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time. -flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS. -use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time. 
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None, llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experiment tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable FLOP counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
+    flop_counter_start: int = 3 # The step at which to start counting FLOPs; the default of 3 means the counter starts after 3 warmup steps.
+    use_profiler: bool = False # Enable the PyTorch profiler; cannot be used with the FLOP counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
```

* [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
diff --git a/docs/single_gpu.md b/docs/single_gpu.md
index fa35ca08..65e81df5 100644
--- a/docs/single_gpu.md
+++ b/docs/single_gpu.md
@@ -71,36 +71,47 @@ python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization
It let us specify the training settings, everything from `model_name` to `dataset_name`, `batch_size` etc. can be set here. Below is the list of supported settings:
```python
-
-model_name: str="PATH/to/LLAMA 2/7B"
-enable_fsdp: bool= False
-run_validation: bool=True
-batch_size_training: int=4
-gradient_accumulation_steps: int=1
-num_epochs: int=3
-num_workers_dataloader: int=2
-lr: float=2e-4
-weight_decay: float=0.0
-gamma: float= 0.85
-use_fp16: bool=False
-mixed_precision: bool=True
-val_batch_size: int=4
-dataset = "samsum_dataset" # alpaca_dataset,grammar_dataset
-peft_method: str = "lora" # None , llama_adapter, prefix
-use_peft: bool=False
-output_dir: str = "./ft-output"
-freeze_layers: bool = False
-num_freeze_layers: int = 1
-quantization: bool = False
-one_gpu: bool = False
-save_model: bool = False
-dist_checkpoint_root_folder: str="model_checkpoints"
-dist_checkpoint_folder: str="fine-tuned"
-save_optimizer: bool=False
-flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None, llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experiment tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable FLOP counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
+    flop_counter_start: int = 3 # The step at which to start counting FLOPs; the default of 3 means the counter starts after 3 warmup steps.
+    use_profiler: bool = False # Enable the PyTorch profiler; cannot be used with the FLOP counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
```

* [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
diff --git a/recipes/finetuning/README.md b/recipes/finetuning/README.md
index 50ce31f6..22849e31 100644
--- a/recipes/finetuning/README.md
+++ b/recipes/finetuning/README.md
@@ -23,37 +23,47 @@ If you are new to fine-tuning techniques, check out an overview: [](./LLM_finetu
It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on.
Below is the list of supported settings:
```python
-
-model_name: str="PATH/to/LLAMA 2/7B"
-enable_fsdp: bool=False
-run_validation: bool=True
-batch_size_training: int=4
-gradient_accumulation_steps: int=1
-max_train_step: int=0
-max_eval_step: int=0
-num_epochs: int=3
-num_workers_dataloader: int=2
-lr: float=2e-4
-weight_decay: float=0.0
-gamma: float=0.85
-use_fp16: bool=False
-mixed_precision: bool=True
-val_batch_size: int=4
-dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
-peft_method: str="lora" # None , llama_adapter, prefix
-use_peft: bool=False
-output_dir: str="./ft-output"
-freeze_layers: bool = False
-num_freeze_layers: int = 1
-quantization: bool = False
-save_model: bool = False
-dist_checkpoint_root_folder: str="model_checkpoints"
-dist_checkpoint_folder: str="fine-tuned"
-save_optimizer: bool=False
-flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None, llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experiment tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable FLOP counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
+    flop_counter_start: int = 3 # The step at which to start counting FLOPs; the default of 3 means the counter starts after 3 warmup steps.
+    use_profiler: bool = False # Enable the PyTorch profiler; cannot be used with the FLOP counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
```

* [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
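For reference, the settings listed above map one-to-one onto the `train_config` dataclass updated in `src/llama_recipes/configs/training.py` (next diff section). A minimal sketch of overriding a few of them in Python follows; the import path `llama_recipes.configs.training` is assumed from the file layout, and the model path is a placeholder.

```python
# Minimal sketch: overriding a few of the settings listed above in Python.
# Assumes the installed package exposes the dataclass at this path
# (it is defined in src/llama_recipes/configs/training.py).
from llama_recipes.configs.training import train_config

cfg = train_config()              # all fields default to the values listed above
cfg.model_name = "PATH/to/Model"  # placeholder path, same as in the docs
cfg.flop_counter = True           # measure model throughput
cfg.flop_counter_start = 3        # start counting FLOPs after 3 warmup steps
cfg.use_profiler = False          # mutually exclusive with flop_counter

print(cfg.flop_counter, cfg.flop_counter_start)
```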
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 4b21872c..2073e04e 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass
@dataclass
class train_config:
-    model_name: str="PATH/to/LLAMA/7B"
+    model_name: str="PATH/to/Model"
    tokenizer_name: str=None
    enable_fsdp: bool=False
    low_cpu_fsdp: bool=False
@@ -29,7 +29,7 @@ class train_config:
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "samsum_dataset"
-    peft_method: str = "lora" # None , llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter, prefix
    use_peft: bool=False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
diff --git a/src/llama_recipes/utils/flop_utils.py b/src/llama_recipes/utils/flop_utils.py
index 36ad83b0..dcdb28e3 100644
--- a/src/llama_recipes/utils/flop_utils.py
+++ b/src/llama_recipes/utils/flop_utils.py
@@ -1,5 +1,5 @@
 from typing import Any, Dict, List, Optional, Union
-
+import time
 import torch
 from torch.utils.flop_counter import FlopCounterMode
@@ -15,14 +15,12 @@ class FlopMeasure(FlopCounterMode):
    .. code-block:: python

-        mod = ...
-        flop_counter = FlopMeasure(mod)
+        model = ...
+        flop_counter = FlopMeasure(model, rank=0, warmup_step=3)
        for batch in enumerate(dataloader):
            with flop_counter:
-                if step == 3:
-                    flop_counter.start_counting()
-                mod(batch)
-                flop_counter.stop_counting()
+                model(batch)
+                flop_counter.step()
    """

    def __init__(
@@ -32,50 +30,58 @@ class FlopMeasure(FlopCounterMode):
        display: bool = True,
        custom_mapping: Dict[Any, Any] = None,
        rank=None,
+        warmup_step: int = 3,
    ):
        super().__init__(mods, depth, display, custom_mapping)
-        self.ready = False
        self.rank = rank
+        self.warmup_step = warmup_step
+        self.start_time = 0
+        self.end_time = 0

+    def step(self):
+        # Decrease warmup_step by 1 on every step; FLOP counting starts once warmup_step reaches 0, and it stops decreasing once it reaches -1.
+        if self.warmup_step >= 0:
+            self.warmup_step -= 1
+        if self.warmup_step == 0 and self.start_time == 0:
+            self.start_time = time.time()
+        elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
+            self.end_time = time.time()
    def __enter__(self):
-        self.ready = False
+        if self.warmup_step == 0:
+            self.start_time = time.time()
        super().__enter__()
        return self
-
+    def is_done(self):
+        return self.warmup_step == -1
    def get_total_flops(self):
        return super().get_total_flops()
-
+    def get_flops_per_sec(self):
+        if self.start_time == 0 or self.end_time == 0:
+            print("Warning: flop count did not finish correctly")
+            return 0
+        return super().get_total_flops() / (self.end_time - self.start_time)
    def get_table(self, depth=2):
        return super().get_table(depth)

    def __exit__(self, *args):
-        self.ready = False
        if self.get_total_flops() == 0:
            print(
                "Warning: did not record any flops this time. 
Skipping the flop report" ) else: - self.stop_counting() if self.display: if self.rank is None or self.rank == 0: - print("self.flop_counts", self.get_total_flops()) + print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time)) + print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12)) + print("The tflop_count table is below:") print(self.get_table(self.depth)) # Disable the display feature so that we don't print the table again self.display = False super().__exit__(*args) - def start_counting(self): - self.ready = True - - def is_ready(self): - return self.ready - - def stop_counting(self): - self.ready = False - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - # return the original output if not ready - if not self.ready: - return func(*args, **kwargs) - # otherwise, count the flops and return the original output - return super().__torch_dispatch__(func, types, args, kwargs) + # when warmup_step is 0, count the flops and return the original output + if self.warmup_step == 0: + return super().__torch_dispatch__(func, types, args, kwargs) + # otherwise, just return the original output + return func(*args, **kwargs) diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index 06b38918..a71447ea 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -59,9 +59,9 @@ def profile(cfg, local_rank=None): ) as torch_profiler: yield torch_profiler elif use_flop_counter: - if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start: - raise ValueError(f"flop counter requires at least {cfg.flop_counter_start} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}") - with FlopMeasure(rank=local_rank) as flop_counter: + if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start: + raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}") + with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter: yield flop_counter else: torch_profiler = contextlib.nullcontext() @@ -135,9 +135,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if not train_config.enable_fsdp or local_rank==0: print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1) break - if train_config.flop_counter and total_train_steps == train_config.flop_counter_start: - print("start flop counting at the step: ", total_train_steps) - profile_context.start_counting() for key in batch.keys(): if train_config.enable_fsdp: if is_xpu_available(): @@ -183,11 +180,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche optimizer.step() optimizer.zero_grad() pbar.update(1) - if train_config.use_profiler: + if train_config.use_profiler or train_config.flop_counter: profile_context.step() - if train_config.flop_counter and profile_context.is_ready(): - TFlops = profile_context.get_total_flops() / 1e12 - profile_context.stop_counting() + if train_config.flop_counter and profile_context.is_done(): + TFlops = profile_context.get_flops_per_sec() / 1e12 if wandb_run: if not train_config.enable_fsdp or rank==0: wandb_run.log({ -- GitLab