From c33ea3cacb607a8b3cf9090b756504f869db8b9c Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 11 Sep 2023 07:19:10 -0700
Subject: [PATCH] Fix pbar update for gradient accumulation
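
With gradient accumulation, the progress bar's total is the number of
optimizer steps per epoch (the bar is created with
total = len(train_dataloader) // gradient_accumulation_steps), not the
number of batches. Advancing it by gradient_accumulation_steps on every
optimizer step therefore made it reach 100% too early; advance it by 1
per optimizer step instead.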

---
 src/llama_recipes/utils/train_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
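
A minimal sketch of the corrected accounting, assuming the bar is
created with total = len(train_dataloader) // gradient_accumulation_steps
(num_batches below is an illustrative stand-in, not the module's API):

    from tqdm import tqdm

    num_batches = 8                  # stands in for len(train_dataloader)
    gradient_accumulation_steps = 4
    pbar = tqdm(total=num_batches // gradient_accumulation_steps)
    for step in range(num_batches):
        # loss.backward() runs on every batch; the optimizer (and the
        # bar) advance only once per accumulation window
        if (step + 1) % gradient_accumulation_steps == 0 or step == num_batches - 1:
            pbar.update(1)           # was: pbar.update(gradient_accumulation_steps)
    pbar.close()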

diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index cd5bf2c0..2f8faaee 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -86,15 +86,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
+                        pbar.update(1)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
-                
+                        pbar.update(1)
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
                 
-- 
GitLab