diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 711b1dadef16459803c2f70728c9acb04f57b51f..2a15b84f1f8fa0d405bb6d84e0ee495c038bb4ed 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -50,7 +50,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
         import wandb
     except ImportError:
         raise ImportError(
-            "You are trying to use wandb which is not currently installed"
+            "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
@@ -59,7 +59,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     update_config(wandb_config, **kwargs)
     run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
     run.config.update(train_config)
-    run.config.update(fsdp_config)
+    run.config.update(fsdp_config, allow_val_change=True)
     return run
 
 
@@ -84,6 +84,8 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+
     if train_config.enable_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
@@ -152,9 +154,8 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-        if train_config.enable_wandb:
-            if not train_config.enable_fsdp or rank==0:
-                wandb_run.config.update(peft_config)
+        if wandb_run:
+            wandb_run.config.update(peft_config)
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -260,7 +261,7 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
-        wandb_run if train_config.enable_wandb else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 5138e580bdb90405e23a3f4862e19900a01af7bf..69da5f56b263f0853cb75b0aaa38ed46f3bf7458 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -275,11 +275,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run:
-        if not train_config.enable_fsdp or rank==0:
-            wandb_run.log({
-                'eval/perplexity': eval_ppl,
-                'eval/loss': eval_epoch_loss,
-            }, commit=False)
+        wandb_run.log({
+            'eval/perplexity': eval_ppl,
+            'eval/loss': eval_epoch_loss,
+        }, commit=False)
 
     return eval_ppl, eval_epoch_loss
 