Commit aff0493f authored by Zhengxiao Du

Fix train_step during gradient accumulation

parent 653ab4db
@@ -386,7 +386,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('optimizer').stop()
         if complete or single_step:
             break
-    return lm_loss_reduced, skipped_iter, metrics
+    lm_loss_total /= count
+    metrics_total = {key: value / count for key, value in metrics_total.items()}
+    return lm_loss_total, skipped_iter, metrics_total

 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
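
For context, here is a minimal runnable sketch of the averaging pattern this commit introduces: during gradient accumulation, train_step sums the language-model loss and per-key metrics across micro-steps, then divides by the step count before returning, so callers see the average over the accumulation window rather than the raw sum. Only lm_loss_total, metrics_total, and count mirror the diff; the loop structure, the forward_step callback, and the gradient scaling are assumptions about the surrounding code.

def train_step_sketch(forward_step, data_iterator, model, optimizer,
                      num_micro_steps):
    """One optimizer step spanning num_micro_steps accumulation micro-steps.

    forward_step(data_iterator, model) is assumed (hypothetical helper) to
    return (loss: torch.Tensor, metrics: dict[str, float]).
    """
    lm_loss_total, metrics_total, count = 0.0, {}, 0
    optimizer.zero_grad()
    for _ in range(num_micro_steps):
        loss, metrics = forward_step(data_iterator, model)
        # Scale so the accumulated gradient matches a single full-batch step.
        (loss / num_micro_steps).backward()
        lm_loss_total += loss.item()
        for key, value in metrics.items():
            metrics_total[key] = metrics_total.get(key, 0.0) + value
        count += 1
    optimizer.step()
    # The change this commit makes: report averages over the accumulation
    # window instead of the accumulated sums.
    lm_loss_total /= count
    metrics_total = {key: value / count for key, value in metrics_total.items()}
    return lm_loss_total, metrics_total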