diff --git a/recipes/finetuning/README.md b/recipes/finetuning/README.md
index 7eafe88e669537e0eaf73b2c57c91e6fd3aa643c..dd99a91c6b60a0d7cefbcda5b974844321916f28 100644
--- a/recipes/finetuning/README.md
+++ b/recipes/finetuning/README.md
@@ -50,9 +50,9 @@
 save_model: bool = False
 dist_checkpoint_root_folder: str="model_checkpoints"
 dist_checkpoint_folder: str="fine-tuned"
 save_optimizer: bool=False
-flop_counter: bool=False # Enable Flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_startpoint: int=3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with flop counter at the same time.
+flop_counter: bool=False # Enable FLOPS counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
+flop_counter_start: int=3 # The step at which to start counting FLOPS; the default of 3 allows a 3-step warm-up before the counter starts.
+use_profiler: bool=False # Enable the PyTorch profiler; cannot be used with the FLOPS counter at the same time.
 profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
 ```
@@ -94,8 +94,8 @@ You'll be able to access a dedicated project or run link on [wandb.ai](https://w
 <img src="../../docs/images/wandb_screenshot.png" alt="wandb screenshot" width="500" />
 </div>
 
-## FLop Counting and Pytorch Profiling
+## FLOPS Counting and PyTorch Profiling
 
-To help with benchmarking effort, we are adding the support for counting the flops during the fine-tuning process. You can achieve this by setting `--flop_counter` when launching your single/multi GPU fine-tuning. Use `--flop_counter_startpoint` to choose which step to count the flops. It is recommended to allow a warmup stage before using the flop counter.
+To help with benchmarking efforts, we are adding support for counting FLOPS during the fine-tuning process. You can enable this by passing `--flop_counter` when launching your single/multi GPU fine-tuning. Use `--flop_counter_start` to choose the step at which to start counting FLOPS. It is recommended to allow a warm-up stage before enabling the FLOPS counter.
 
-Similarly, you can set `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). This would be helpful for debugging purposes. However, the `--flop_counter` and `--use_profiler` can not be used in the same time to ensure the measurement accuarcy.
+Similarly, you can set the `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model with the [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). This is helpful for debugging purposes. However, `--flop_counter` and `--use_profiler` cannot be used at the same time, to ensure measurement accuracy.
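Editor's note: the README change above documents two mutually exclusive flags. The snippet below is a minimal, illustrative sketch of that rule only; `MeasurementConfig` and `validate_measurement_flags` are hypothetical names, not part of llama_recipes, and the fields simply mirror the documented defaults.

```python
# Illustrative sketch only -- not the actual llama_recipes configuration code.
from dataclasses import dataclass

@dataclass
class MeasurementConfig:
    flop_counter: bool = False        # count FLOPS to measure model throughput
    flop_counter_start: int = 3       # start counting after a 3-step warm-up
    use_profiler: bool = False        # capture PyTorch profiler traces instead
    profiler_dir: str = "PATH/to/save/profiler/results"

def validate_measurement_flags(cfg: MeasurementConfig) -> None:
    # The README states the two tools cannot be enabled at the same time,
    # since profiler overhead would skew the FLOPS measurement.
    if cfg.flop_counter and cfg.use_profiler:
        raise ValueError("--flop_counter and --use_profiler cannot be enabled together")

validate_measurement_flags(MeasurementConfig(flop_counter=True))  # passes
```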
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index dcc80f3e376cd8686d8dc98e26493beecba4ac9e..eac8d1980c597cc6d631444e67181cc872cef3be 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -42,7 +42,7 @@ class train_config:
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_wandb: bool = False # Enable wandb for experient tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting
-    flop_counter: bool = False # Enable Flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
-    flop_counter_startpoint: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
+    flop_counter: bool = False # Enable FLOPS counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
+    flop_counter_start: int = 3 # The step at which to start counting FLOPS; the default of 3 allows a 3-step warm-up before the counter starts.
     use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
     profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index e90b03dfc9c0070de60a5ed849378b9bdb9386f9..213653698beaa72ed873aa2252a92dc8568c143f 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -59,8 +59,8 @@ def throughput_measure_context(cfg, local_rank=None):
         ) as torch_profiler:
             yield torch_profiler
     elif use_flop_counter:
-        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_startpoint:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_startpoint} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
         with FlopMeasure(rank=local_rank) as flop_counter:
             yield flop_counter
     else:
@@ -136,7 +136,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.enable_fsdp or local_rank==0:
                         print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
                     break
-                if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
+                if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
                     print("start flop counting at the step: ", total_train_steps)
                     measure_context.start_counting()
                 for key in batch.keys():
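Editor's note: the train loop above calls `measure_context.start_counting()` once `flop_counter_start` is reached. The sketch below is a hypothetical, minimal stand-in for a `FlopMeasure`-style helper, assuming PyTorch's `torch.utils.flop_counter.FlopCounterMode` (available in recent PyTorch releases); the actual `FlopMeasure` in llama_recipes may be implemented differently.

```python
# Hypothetical sketch, not the actual llama_recipes FlopMeasure implementation.
from torch.utils.flop_counter import FlopCounterMode

class SimpleFlopMeasure:
    """Counts FLOPS only after start_counting() is called, mirroring the
    flop_counter_start warm-up behaviour used in the diff above."""

    def __init__(self, rank=None):
        self.rank = rank
        self.mode = FlopCounterMode(display=False)
        self.counting = False

    def __enter__(self):
        # Warm-up steps run outside the dispatch mode, so they are not counted.
        return self

    def start_counting(self):
        if not self.counting:
            self.mode.__enter__()
            self.counting = True

    def __exit__(self, exc_type, exc_value, traceback):
        if self.counting:
            self.mode.__exit__(exc_type, exc_value, traceback)
            if self.rank in (None, 0):
                print("total FLOPS counted after warm-up:", self.mode.get_total_flops())
        return False
```

Entering the dispatch mode lazily is just one way to exclude the warm-up steps; the upstream helper may instead keep the counter active and gate accumulation internally.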