diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index eaaea93606fb5f78b4a0818ed235652230a3d994..e7481769970e6e95525159b5eaeb1a3a4d389b53 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -386,7 +386,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, timers('optimizer').stop() if complete or single_step: break - return lm_loss_reduced, skipped_iter, metrics + lm_loss_total /= count + metrics_total = {key: value / count for key, value in metrics_total.items()} + return lm_loss_total, skipped_iter, metrics_total def backward_step(optimizer, model, loss, args, timers): """Backward step."""