diff --git a/.gitignore b/.gitignore
index 0c04a67d1a543f6fbf0d097ac1287a7f1902aecd..9c4cf7861cbea8df2fca9ea83d77e0f774a6db47 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+.gitignore
+wandb/
+artifacts/
diff --git a/src/llama_recipes/configs/__init__.py b/src/llama_recipes/configs/__init__.py
index 6aabbb9969cd9cc749fded9b4ac09e3a52be1c46..5db9c216bbe566dbfaef2e05bd76a72c087a450b 100644
--- a/src/llama_recipes/configs/__init__.py
+++ b/src/llama_recipes/configs/__init__.py
@@ -4,3 +4,4 @@
 from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
+from llama_recipes.configs.wandb import wandb_config
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 354c534eb067a20fb5485ee7aef3016540b154b7..44f677058c8f882685965d884ac6e9f951face35 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -36,3 +36,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
+    enable_wandb: bool = False # Enable wandb for experiment tracking
diff --git a/src/llama_recipes/configs/wandb.py b/src/llama_recipes/configs/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3422fa288f338fd11c79951cb07b24b3f331a9
--- /dev/null
+++ b/src/llama_recipes/configs/wandb.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass, field
+
+@dataclass
+class wandb_config:
+    wandb_project: str='llama_recipes' # wandb project name
+    wandb_entity: str='none' # wandb entity name
+    wandb_log_model: bool=False # whether to log the model as an artifact at the end of training
+    wandb_watch: str='false' # can be set to 'gradients' or 'all' to log gradients and parameters
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 2ec5c2340fb067136b41128271d55bebf6577be7..306abc8482ce7a978cc4971bc14a3f5436554a06 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -11,6 +11,7 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
+
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
@@ -45,12 +46,29 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
-
+def setup_wandb(train_config, fsdp_config, **kwargs):
+    try:
+        import wandb
+    except ImportError:
+        raise ImportError(
+            "You are trying to use wandb which is not currently installed. "
+            "Please install it using pip install wandb"
+        )
+    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    wandb_config = WANDB_CONFIG()
+    update_config(wandb_config, **kwargs)
+    wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
+    run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
+    run.config.update(train_config)
+    run.config.update(fsdp_config)
+    return run
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
+
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
 
@@ -68,6 +86,11 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+    if train_config.enable_wandb:
+        if not train_config.enable_fsdp or rank==0:
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -89,6 +112,7 @@ def main(**kwargs):
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
             )
+
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             llama_config.use_cache = use_cache
@@ -132,6 +156,10 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+        if train_config.enable_wandb:
+            if not train_config.enable_fsdp or rank==0:
+                wandb_run.config.update(peft_config)
+
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -237,9 +265,14 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
+        wandb_run if train_config.enable_wandb else None,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+        if train_config.enable_wandb:
+            for k,v in results.items():
+                wandb_run.summary[k] = v
+
 
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 7bb759a10b7b10ca7d66d145a3f730b7c4c988a8..5138e580bdb90405e23a3f4862e19900a01af7bf 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -31,7 +31,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def byte2mb(x):
     return int(x / 2**20)
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
     Trains the model on the given dataloader
 
@@ -99,6 +99,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
+                if wandb_run:
+                    if not train_config.enable_fsdp or rank==0:
+                        wandb_run.log({
+                            'train/epoch': epoch + 1,
+                            'train/step': epoch * len(train_dataloader) + step,
+                            'train/loss': loss.detach().float(),
+                        })
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
 
         pbar.close()
@@ -133,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -213,7 +221,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
+def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -266,6 +274,13 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
+    if wandb_run:
+        if not train_config.enable_fsdp or local_rank==0:
+            wandb_run.log({
+                'eval/perplexity': eval_ppl,
+                'eval/loss': eval_epoch_loss,
+            }, commit=False)
+
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
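
For reference, a minimal usage sketch of the options introduced above, once the patch is applied. This is illustrative only: it assumes the package is installed as llama_recipes, that wandb is installed and logged in, and that the model checkpoint and PEFT settings shown are placeholders rather than part of this diff.

    # Illustrative sketch: launch fine-tuning with the new wandb tracking enabled.
    # Kwargs that match wandb_config fields (e.g. wandb_project) are picked up by
    # update_config(wandb_config, **kwargs) inside setup_wandb.
    from llama_recipes.finetuning import main

    main(
        model_name="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint path
        use_peft=True,
        peft_method="lora",
        enable_wandb=True,              # new train_config flag added in this diff
        wandb_project="llama_recipes",  # wandb_config field
    )

The same flags can be passed on the command line through fire, e.g. python -m llama_recipes.finetuning --enable_wandb True --wandb_project llama_recipes.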