diff --git a/utils/train_utils.py b/utils/train_utils.py index 08679a4b867ab728dbc0ec3d7369a3e9c8e77225..95b968ef1794fd79cd24daa8b807e2683a4d66cc 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -153,7 +153,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if eval_epoch_loss < best_val_loss: best_val_loss = eval_epoch_loss - print(f"best eval loss on epoch {epoch} is {best_val_loss}") + if train_config.enable_fsdp: + if rank==0: + print(f"best eval loss on epoch {epoch} is {best_val_loss}") + else: + print(f"best eval loss on epoch {epoch} is {best_val_loss}") val_loss.append(best_val_loss) val_prep.append(eval_ppl)