diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 1a9cff9d58ad17e3243df55f5a7eefe11bab55fd..d8e5f307452b08f26ea6f666e036ac49e174e67c 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -38,4 +38,7 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    flop_counter: bool=True # Enable FLOP counting for one training step
+    profiler: bool=True # Enable the PyTorch profiler
+    profile_output_dir: str="profile_output" # Directory for TensorBoard profiler traces
     save_metrics: bool = False # saves training metrics to a json file for later plotting
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 86a651cc0fdef6857d10497c70b7f19e59f0c44c..775c3ca5063237ee91f5462e55a58f95f866403a 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -3,9 +3,10 @@

 import os
 from pkg_resources import packaging
-
+import gc
 import fire
 import random
+
 import torch
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
@@ -44,9 +45,12 @@ from llama_recipes.utils.train_utils import (
     print_model_size,
     get_policies
 )
+
 from accelerate.utils import is_xpu_available

 def main(**kwargs):
+    gc.disable()  # turn off automatic garbage collection; train()/evaluation() trigger gen-1 collections explicitly
+    gc.collect(1)
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
@@ -83,11 +87,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
diff --git a/src/llama_recipes/utils/__init__.py b/src/llama_recipes/utils/__init__.py
index 6cba8f1ef459094f736a1355cc8030dc1df94353..da73bebef8e312ff380d02983de262a4698a35e0 100644
--- a/src/llama_recipes/utils/__init__.py
+++ b/src/llama_recipes/utils/__init__.py
@@ -4,4 +4,5 @@
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
 from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
-from llama_recipes.utils.train_utils import *
\ No newline at end of file
+from llama_recipes.utils.train_utils import *
+from llama_recipes.utils.tflop_counter import *
\ No newline at end of file
diff --git a/src/llama_recipes/utils/tflop_counter.py b/src/llama_recipes/utils/tflop_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdec4753528e9047a5d72f804a3fcc8b235a5fa1
--- /dev/null
+++ b/src/llama_recipes/utils/tflop_counter.py
@@ -0,0 +1,464 @@
+# Temporary copy of Horace He's FLOP counter.
+# This version supports distributed runs so the table is not printed on every GPU.
+# Remove after the upstream file is updated.
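+#
+# (This appears to mirror torch.utils.flop_counter.FlopCounterMode upstream; the
+# optional `rank` argument and the rank-0-only printing are the local additions.)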
+
+import torch
+from torch.utils._pytree import tree_map
+from typing import List, Any, Dict, Optional, Union
+from collections import defaultdict
+from torch.utils._python_dispatch import TorchDispatchMode
+from math import prod
+
+__all__ = ["FlopCounterMode"]
+
+aten = torch.ops.aten
+
+
+def get_shape(i):
+    if isinstance(i, torch.Tensor):
+        return i.shape
+    return i
+
+
+def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for matmul.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contain the shapes of the two matrices.
+    m, k = a_shape
+    k2, n = b_shape
+    assert k == k2
+    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
+    return m * n * 2 * k
+
+
+def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for addmm.
+    """
+    return mm_flop(a_shape, b_shape)
+
+
+def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for the bmm operation.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contain the shapes of the two tensors.
+    b, m, k = a_shape
+    b2, k2, n = b_shape
+    assert b == b2
+    assert k == k2
+    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
+    flop = b * m * n * 2 * k
+    return flop
+
+
+def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for the baddbmm operation.
+    """
+    # Inputs should be a list of length 3.
+    # Inputs contain the shapes of the three tensors.
+    return bmm_flop(a_shape, b_shape)
+
+
+def conv_flop_count(
+    x_shape: List[int],
+    w_shape: List[int],
+    out_shape: List[int],
+    transposed: bool = False,
+) -> int:
+    """
+    Count flops for convolution. Note that only multiplications are
+    counted; computation for the bias is ignored.
+    Flops for a transposed convolution are calculated as
+    flops = (x_shape[2:] * prod(w_shape) * batch_size).
+    Args:
+        x_shape (list(int)): The input shape before convolution.
+        w_shape (list(int)): The filter shape.
+        out_shape (list(int)): The output shape after convolution.
+        transposed (bool): is the convolution transposed
+    Returns:
+        int: the number of flops
+    """
+    batch_size = x_shape[0]
+    conv_shape = (x_shape if transposed else out_shape)[2:]
+    c_out, c_in, *dims = w_shape
+
+    # NB(chilli): I don't think this properly accounts for padding :think:
+    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
+    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
+    return flop
+
+
+def conv_flop(
+    x_shape,
+    w_shape,
+    _bias,
+    _stride,
+    _padding,
+    _dilation,
+    transposed,
+    *args,
+    out_shape=None,
+    **kwargs
+) -> int:
+    """
+    Count flops for convolution.
+    """
+    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
+
+
+def transpose_shape(shape):
+    return [shape[1], shape[0]] + list(shape[2:])
+
+
+def conv_backward_flop(
+    grad_out_shape,
+    x_shape,
+    w_shape,
+    _bias,
+    _stride,
+    _padding,
+    _dilation,
+    transposed,
+    _output_padding,
+    _groups,
+    output_mask,
+    out_shape,
+) -> int:
+    flop_count = 0
+
+    if output_mask[0]:
+        grad_input_shape = get_shape(out_shape[0])
+        flop_count += conv_flop_count(
+            grad_out_shape, w_shape, grad_input_shape, not transposed
+        )
+    if output_mask[1]:
+        grad_weight_shape = get_shape(out_shape[1])
+        flop_count += conv_flop_count(
+            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed
+        )
+
+    return flop_count
+
+
+def sdpa_flop_count(query_shape, key_shape, value_shape):
+    """
+    Count flops for self-attention.
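+    Only the two batched matmuls (q @ k^T and softmax(scores) @ v) are counted;
+    the softmax itself is not included in the count.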
+    NB: We can assume that value_shape == key_shape
+    """
+    b, h, s_q, d_q = query_shape
+    _b2, _h2, s_k, _d2 = key_shape
+    _b3, _h3, _s3, d_v = value_shape
+    assert (
+        b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
+    )
+    total_flops = 0
+    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
+    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
+    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
+    return total_flops
+
+
+def sdpa_flop(
+    query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
+) -> int:
+    """
+    Count flops for self-attention.
+    """
+    # NB: We aren't accounting for causal attention here
+    return sdpa_flop_count(query_shape, key_shape, value_shape)
+
+
+def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
+    total_flops = 0
+    b, h, s_q, d_q = query_shape
+    _b2, _h2, s_k, _d2 = key_shape
+    _b3, _h3, _s3, d_v = value_shape
+    _b4, _h4, _s4, _d4 = grad_out_shape
+    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
+    assert d_v == _d4 and s_k == _s3 and s_q == _s4
+    total_flops = 0
+    # Step 1: We recompute the scores matrix.
+    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
+
+    # Step 2: We propagate the gradients through the scores @ v operation.
+    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
+    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
+    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
+
+    # Step 3: We propagate the gradients through the q @ k operation.
+    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
+    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
+    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
+    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
+    return total_flops
+
+
+def sdpa_backward_flop(
+    grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
+) -> int:
+    """
+    Count flops for self-attention backward.
+    """
+    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
+
+
+flop_mapping = {
+    aten.mm: mm_flop,
+    aten.addmm: addmm_flop,
+    aten.bmm: bmm_flop,
+    aten.baddbmm: baddbmm_flop,
+    aten.convolution: conv_flop,
+    aten._convolution: conv_flop,
+    aten.convolution_backward: conv_backward_flop,
+    aten._scaled_dot_product_efficient_attention: sdpa_flop,
+    aten._scaled_dot_product_flash_attention: sdpa_flop,
+    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
+    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
+}
+
+
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+# Define the suffixes for different orders of magnitude
+suffixes = ["", "K", "M", "B", "T"]
+
+
+# Thanks BingChat!
+def get_suffix_str(number):
+    # Find the index of the appropriate suffix based on the number of digits,
+    # with some additional overflow.
+    # i.e. 1.01B should be displayed as 1001M, not 1.001B
+    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 3) // 3))
+    return suffixes[index]
+
+
+def convert_num_with_suffix(number, suffix):
+    index = suffixes.index(suffix)
+    # Divide the number by 1000^index and format it to three decimal places
+    value = "{:.3f}".format(number / (1000**index))
+    # Return the value and the suffix as a string
+    return value + suffixes[index]
+
+
+class FlopCounterMode(TorchDispatchMode):
+    """
+    ``FlopCounterMode`` is a context manager that counts the number of
+    flops within its context. It does this using a ``TorchDispatchMode``.
+
+    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
+
+    Example usage
+
+    .. code-block:: python
+
+        mod = ...
+        flop_counter = FlopCounterMode(mod)
+        with flop_counter:
+            mod.sum().backward()
+
+    """
+
+    def __init__(
+        self,
+        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
+        depth: int = 2,
+        display: bool = True,
+        custom_mapping: Dict[Any, Any] = None,
+        rank=None,
+    ):
+        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(
+            lambda: defaultdict(int)
+        )
+        self.depth = depth
+        self.parents = ["Global"]
+        self.display = display
+        self.rank = rank
+
+        if custom_mapping is None:
+            custom_mapping = {}
+        if isinstance(mods, torch.nn.Module):
+            mods = [mods]
+        self.mods = mods
+        if mods is not None:
+            for mod in mods:
+                prefix = type(mod).__name__
+                for name, module in dict(mod.named_modules()).items():
+                    if name == "":
+                        name = prefix
+                    else:
+                        name = ".".join([prefix, name])
+                    module.register_forward_pre_hook(self._enter_module(name))
+                    module.register_forward_hook(self._exit_module(name))
+        self.flop_mapping = {**flop_mapping, **custom_mapping}
+
+    def _enter_module(self, name):
+        def f(module, inputs):
+            inputs = normalize_tuple(inputs)
+            out = self._create_pre_module(name)(*inputs)
+            return out
+
+        return f
+
+    def _exit_module(self, name):
+        def f(module, inputs, outputs):
+            outputs = normalize_tuple(outputs)
+            return self._create_post_module(name)(*outputs)
+
+        return f
+
+    def _create_post_module(self, name):
+        class PushState(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *args):
+                assert self.parents[-1] == name
+                self.parents.pop()
+                args = tree_map(
+                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
+                )
+                if len(args) == 1:
+                    return args[0]
+                return args
+
+            @staticmethod
+            def backward(ctx, *grad_outs):
+                self.parents.append(name)
+                return grad_outs
+
+        return PushState.apply
+
+    def _create_pre_module(self, name):
+        class PopState(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *args):
+                self.parents.append(name)
+                args = tree_map(
+                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
+                )
+                if len(args) == 1:
+                    return args[0]
+                return args
+
+            @staticmethod
+            def backward(ctx, *grad_outs):
+                assert self.parents[-1] == name
+                self.parents.pop()
+                return grad_outs
+
+        return PopState.apply
+
+    def get_total_flops(self) -> int:
+        return sum(self.flop_counts["Global"].values())
+
+    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
+        """Returns the flop counts as a dictionary of dictionaries. The outer
+        dictionary is keyed by module name, and the inner dictionary is keyed by
+        operation name.
+
+        Returns:
+            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
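+            The inner keys are ATen overload packets (e.g. ``aten.mm``) and the
+            values are raw FLOP counts, not TFLOPs.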
+        """
+        return dict(self.flop_counts)
+
+    def get_table(self, depth=None):
+        if depth is None:
+            depth = self.depth
+        if depth is None:
+            depth = 999999
+
+        import tabulate
+
+        tabulate.PRESERVE_WHITESPACE = True
+        header = ["Module", "FLOP", "% Total"]
+        values = []
+        global_flops = self.get_total_flops()
+        global_suffix = get_suffix_str(global_flops)
+        is_global_subsumed = False
+
+        def process_mod(mod_name, depth):
+            nonlocal is_global_subsumed
+
+            total_flops = sum(self.flop_counts[mod_name].values())
+
+            is_global_subsumed |= total_flops >= global_flops
+
+            padding = " " * depth
+            values = []
+            values.append(
+                [
+                    padding + mod_name,
+                    convert_num_with_suffix(total_flops, global_suffix),
+                    "{:.2f}%".format(total_flops / global_flops * 100),
+                ]
+            )
+            for k, v in self.flop_counts[mod_name].items():
+                values.append(
+                    [
+                        padding + " - " + str(k),
+                        convert_num_with_suffix(v, global_suffix),
+                        "{:.2f}%".format(v / global_flops * 100),
+                    ]
+                )
+            return values
+
+        for mod in self.flop_counts.keys():
+            if mod == "Global":
+                continue
+            mod_depth = mod.count(".") + 1
+            if mod_depth > depth:
+                continue
+
+            cur_values = process_mod(mod, mod_depth - 1)
+            for value in cur_values:
+                values.append(value)
+
+        # We do a bit of messing around here to only output the "Global" value
+        # if there are any FLOPs in there that aren't already fully contained by
+        # a module.
+        if "Global" in self.flop_counts and not is_global_subsumed:
+            for idx, value in enumerate(values):
+                values[idx][0] = " " + values[idx][0]
+
+            values = process_mod("Global", 0) + values
+
+        if len(values) == 0:
+            values = [["Global", "0", "0%"]]
+
+        return tabulate.tabulate(
+            values, headers=header, colalign=("left", "right", "right")
+        )
+
+    def __enter__(self):
+        self.flop_counts.clear()
+        super().__enter__()
+        return self
+
+    def __exit__(self, *args):
+        if self.display:
+            if self.rank is None or self.rank == 0:
+                print(self.get_table(self.depth))
+        super().__exit__(*args)
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+        out = func(*args, **kwargs)
+        func_packet = func._overloadpacket
+        if func_packet in self.flop_mapping:
+            flop_count_func = self.flop_mapping[func_packet]
+            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
+            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
+            for par in self.parents:
+                self.flop_counts[par][func_packet] += flop_count
+
+        return out
\ No newline at end of file
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 7dd55957740938556cb61ee070df4e1ea53a2d50..451a4b3fb6093208aad082851bc862b7cc9fc79a 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -7,9 +7,10 @@
 import yaml
 from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
+import contextlib
+import gc
 from datetime import datetime
-
 import torch
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
@@ -23,8 +24,39 @@
 import json
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
+
+from llama_recipes.utils.tflop_counter import FlopCounterMode
+
+@contextlib.contextmanager
+def maybe_run_profiler(cfg, *args, **kwargs):
+    use_profiler: bool = cfg.profiler
+
+    if use_profiler:
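+        # With schedule(wait=1, warmup=2, active=3, repeat=1) the profiler skips
+        # step 0, warms up on steps 1-2, and records steps 3-5 only; the schedule
+        # advances only when torch_profiler.step() is called once per iteration.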
print(f"profiling is activated and results will be saved in {cfg.profile_output_dir}") + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + cfg.profile_output_dir + ), + profile_memory=True, + with_stack=False, + record_shapes=True, + ) as torch_profiler: + yield torch_profiler + else: + torch_profiler = contextlib.nullcontext() + yield None + +def get_total_flops(model): + return (sum([v for _, v in model.flop_counts["Global"].items()])) + from accelerate.utils import is_xpu_available, is_ccl_available + def set_tokenizer_params(tokenizer: LlamaTokenizer): tokenizer.pad_token_id = 0 tokenizer.padding_side = "left" @@ -86,6 +118,62 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche total_loss = 0.0 total_length = len(train_dataloader)//gradient_accumulation_steps pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) + + with maybe_run_profiler(train_config) as torch_profiler: + for step, batch in enumerate(train_dataloader): + gc.collect(1) + for key in batch.keys(): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + else: + batch[key] = batch[key].to('cuda:0') + flop_check_done = False + if train_config.flop_counter and step == 3 and not flop_check_done: + flop_counter = FlopCounterMode(rank=local_rank) + with flop_counter: + loss = model(**batch).loss + loss = loss / gradient_accumulation_steps + total_loss += loss.detach().float() + if train_config.use_fp16: + # if fp16 is enabled, use gradient scaler to handle gradient update + scaler.scale(loss).backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + pbar.update(1) + else: + # regular backpropagation when fp16 is not used + loss.backward() + TFlops = get_total_flops(flop_counter) / 1e12 + flop_check_done = True + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + else: + loss = model(**batch).loss + loss = loss / gradient_accumulation_steps + total_loss += loss.detach().float() + if train_config.use_fp16: + # if fp16 is enabled, use gradient scaler to handle gradient update + scaler.scale(loss).backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + pbar.update(1) + else: + # regular backpropagation when fp16 is not used + loss.backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})") + pbar.close() + for step, batch in enumerate(train_dataloader): for key in batch.keys(): if train_config.enable_fsdp: @@ -139,6 +227,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep) pbar.close() + epoch_end_time = time.perf_counter()-epoch_start_time epoch_times.append(epoch_end_time) 
         # Reducing total_loss across all devices if there's more than one CUDA device
@@ -266,6 +355,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+
+    if train_config.flop_counter:
+        results["model_flops"] = TFlops
+
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
@@ -296,6 +389,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            gc.collect(1)
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
@@ -501,4 +595,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
     }
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
\ No newline at end of file
+        json.dump(metrics_data, f)
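Usage sketch (not part of the patch): assuming the new tflop_counter module above is importable from llama_recipes, a one-step measurement outside train() could look like the following. The helper name measure_step_tflops, the model and batch objects, and the 312 TFLOPS figure (an assumed A100 BF16 peak) are illustrative placeholders, not code from this change.

import time

from llama_recipes.utils.tflop_counter import FlopCounterMode


def measure_step_tflops(model, batch, peak_tflops=312.0):
    """Run one forward/backward pass under FlopCounterMode and report raw TFLOPs,
    achieved TFLOPS/s, and model FLOPS utilization against an assumed peak."""
    flop_counter = FlopCounterMode(rank=0, display=False)  # rank/display are arguments of the class above
    start = time.perf_counter()
    with flop_counter:
        loss = model(**batch).loss  # same call pattern as the patched training loop
        loss.backward()
    elapsed = time.perf_counter() - start
    tflops = flop_counter.get_total_flops() / 1e12  # same scaling used for TFlops in train()
    achieved = tflops / elapsed
    return tflops, achieved, achieved / peak_tflops

The value this patch stores in results["model_flops"] is the same get_total_flops() / 1e12 quantity, captured at step 3 of the added training loop (non-fp16 path).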