diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 92a0317a9b596a112281112529c2cd0a0b315228..e72eea533dcc3979bfb2ec4982ae0c3ffce1c52e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -169,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
@@ -492,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
     }
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
\ No newline at end of file
+        json.dump(metrics_data, f)