diff --git a/.gitignore b/.gitignore
index 0c04a67d1a543f6fbf0d097ac1287a7f1902aecd..3ee7b311c33276847a51db96f0fc3d8938cba9f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+wandb/
+artifacts/
diff --git a/README.md b/README.md
index 64f0cb02ce5da4d78d178066ffc4d2614e8b8952..c418e186a314ce55079a4e2639f4f672dbd68145 100644
--- a/README.md
+++ b/README.md
@@ -203,6 +203,18 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
+## Weights & Biases Experiment Tracking
+
+You can enable [W&B](https://wandb.ai/) experiment tracking by using the `use_wandb` flag as shown below. You can change the project name, entity, and other `wandb.init` arguments in `wandb_config`.
+
+```bash
+python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir Path/to/save/PEFT/model --use_wandb
+```
+You'll be able to access a dedicated project or run link on [wandb.ai](https://wandb.ai) and see your dashboard like the one below.
+<div style="display: flex;">
+    <img src="./docs/images/wandb_screenshot.png" alt="wandb screenshot" width="500" />
+</div>
+
 # Evaluation Harness
 Here, we make use `lm-evaluation-harness` from `EleutherAI` for evaluation of fine-tuned Llama 2 models. This also can extend to evaluate other optimizations for inference of Llama 2 model such as quantization. Please use this get started [doc](./eval/README.md).
 
@@ -234,7 +246,7 @@ This folder contains a series of benchmark scripts for Llama 2 models inference
 This repository is organized in the following way:
 
 [benchmarks](./benchmarks): Contains a series of benchmark scripts for Llama 2 models inference on various backends.
 
-[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets.
+[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets, and Weights & Biases experiment tracking.
 
 [docs](docs/): Example recipes for single and multi-gpu fine-tuning recipes.
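As a quick illustration of how the new `wandb_config` feeds `wandb.init`, here is a minimal sketch mirroring the `setup_wandb` helper added to `src/llama_recipes/finetuning.py` further down in this diff. It assumes `wandb` is installed, and the project and entity values are placeholders:

```python
import dataclasses

import wandb

from llama_recipes.configs import wandb_config as WANDB_CONFIG

# Build the default W&B settings and override any wandb.init argument as needed.
cfg = WANDB_CONFIG()
cfg.project = "my_llama_project"  # placeholder project name
cfg.entity = "my_team"            # placeholder entity

# setup_wandb() in finetuning.py boils down to this call pattern:
run = wandb.init(**dataclasses.asdict(cfg))
run.config.update({"model_name": "7B"})  # train/FSDP configs are attached the same way
run.log({"train/loss": 0.0})             # metrics are logged under train/ and eval/ namespaces
run.finish()
```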
diff --git a/docs/images/wandb_screenshot.png b/docs/images/wandb_screenshot.png
new file mode 100644
index 0000000000000000000000000000000000000000..4cbab343a4c02778bd5699a9c595133a318e650b
Binary files /dev/null and b/docs/images/wandb_screenshot.png differ
diff --git a/scripts/spellcheck_conf/wordlist.txt b/scripts/spellcheck_conf/wordlist.txt
index afb395152295875c1ade37edab2b66d47e620f39..2254328d1fa81469611f5eb82a6edbf9332cb302 100644
--- a/scripts/spellcheck_conf/wordlist.txt
+++ b/scripts/spellcheck_conf/wordlist.txt
@@ -1251,3 +1251,10 @@ lm
 prepended
 subtasks
 EleutherAI
+CodeLlama
+LlamaGuard
+OctoAI
+OctoAI's
+PurpleLlama
+Youtube
+wandb
diff --git a/src/llama_recipes/configs/__init__.py b/src/llama_recipes/configs/__init__.py
index 6aabbb9969cd9cc749fded9b4ac09e3a52be1c46..5db9c216bbe566dbfaef2e05bd76a72c087a450b 100644
--- a/src/llama_recipes/configs/__init__.py
+++ b/src/llama_recipes/configs/__init__.py
@@ -4,3 +4,4 @@
 from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
+from llama_recipes.configs.wandb import wandb_config
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 1a9cff9d58ad17e3243df55f5a7eefe11bab55fd..844eb18749a38f347e8cf3ddf2d60b3e559f81f6 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -38,4 +38,5 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experiment tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting
diff --git a/src/llama_recipes/configs/wandb.py b/src/llama_recipes/configs/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a43ffec283006c9df44d0f4ec6397f5c57b4d9c
--- /dev/null
+++ b/src/llama_recipes/configs/wandb.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from typing import List, Optional
+from dataclasses import dataclass, field
+
+@dataclass
+class wandb_config:
+    project: str = 'llama_recipes' # wandb project name
+    entity: Optional[str] = None # wandb entity name
+    job_type: Optional[str] = None
+    tags: Optional[List[str]] = None
+    group: Optional[str] = None
+    notes: Optional[str] = None
+    mode: Optional[str] = None
\ No newline at end of file
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index cddfe45d1102d9a9c995593cf8a5921c260e8943..6b5650b20534f57bf3fd7d8873b8741414b050c8 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -4,6 +4,7 @@
 import os
 from pkg_resources import packaging
 
+import dataclasses
 import fire
 import random
 import torch
@@ -49,11 +50,28 @@ from llama_recipes.utils.train_utils import (
 )
 from accelerate.utils import is_xpu_available
 
+def setup_wandb(train_config, fsdp_config, **kwargs):
+    try:
+        import wandb
+    except ImportError:
+        raise ImportError(
+            "You are trying to use wandb which is not currently installed. "
+            "Please install it using pip install wandb"
+        )
+    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    wandb_config = WANDB_CONFIG()
+    update_config(wandb_config, **kwargs)
+    init_dict = dataclasses.asdict(wandb_config)
+    run = wandb.init(**init_dict)
+    run.config.update(train_config)
+    run.config.update(fsdp_config, allow_val_change=True)
+    return run
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
@@ -75,6 +93,12 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+
+    if train_config.use_wandb:
+        if not train_config.enable_fsdp or rank==0:
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -130,6 +154,9 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+        if wandb_run:
+            wandb_run.config.update(peft_config)
+
 
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
@@ -250,9 +277,13 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+        if train_config.use_wandb:
+            for k,v in results.items():
+                wandb_run.summary[k] = v
 
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d9d7cbb5c17576ad9dcbe2c6c15e01b55276107d..e72eea533dcc3979bfb2ec4982ae0c3ffce1c52e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -33,7 +33,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def byte2mb(x):
     return int(x / 2**20)
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
     Trains the model on the given dataloader
 
@@ -133,6 +133,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
+                if wandb_run:
+                    if not train_config.enable_fsdp or rank==0:
+                        wandb_run.log({
+                            'train/epoch': epoch + 1,
+                            'train/step': epoch * len(train_dataloader) + step,
+                            'train/loss': loss.detach().float(),
+                        })
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
 
                 if train_config.save_metrics:
@@ -161,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
@@ -252,7 +260,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
+def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -315,6 +323,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
+
+    if wandb_run:
+        wandb_run.log({
+            'eval/perplexity': eval_ppl,
+            'eval/loss': eval_epoch_loss,
+        }, commit=False)
 
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
@@ -478,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
     }
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
\ No newline at end of file
+        json.dump(metrics_data, f)
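To illustrate the logging pattern introduced above: the training loop commits `train/` metrics every step, while `evaluation()` logs `eval/` metrics with `commit=False`, so they are buffered and written together with the next committed `wandb.log` call (or at run finish) rather than advancing the W&B step counter on their own; final `results` are exposed as run summary values. The following is a standalone, hedged sketch with placeholder numbers, not part of the patch:

```python
import wandb

# Offline mode keeps the sketch runnable without a W&B account; the project name is a placeholder.
run = wandb.init(project="llama_recipes", mode="offline")

steps_per_epoch = 2
for epoch in range(2):
    for step in range(steps_per_epoch):
        # Committed per step, as in train(); train/step keeps a global step count.
        run.log({
            "train/epoch": epoch + 1,
            "train/step": epoch * steps_per_epoch + step,
            "train/loss": 1.0 / (epoch * steps_per_epoch + step + 1),
        })
    # Logged with commit=False, as in evaluation(): buffered until the next committed log
    # (the first train/ log of the next epoch) or until the run finishes.
    run.log({"eval/perplexity": 2.0, "eval/loss": 0.7}, commit=False)

# Final results land in the run summary, as finetuning.py does with `results`.
run.summary["best_eval_loss"] = 0.7
run.finish()
```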