diff --git a/README.md b/README.md
index 099bfbab7a1e0df67fea946ae04574b9ab45bd66..22ae720d98d7f7819d57650865dd5caddfbc83b0 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
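+
+If a larger effective batch size is needed while keeping `--batch_size_training 1`, gradient accumulation can be layered on top by setting the new `gradient_accumulation_steps` training-config field, passed as a command-line override like the other flags shown here (the value of 4 below is purely illustrative):
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --gradient_accumulation_steps 4 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+```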
 
diff --git a/docs/inference.md b/docs/inference.md
index 6c841475f51caa2ca756e3dd3a4bae9d40e15958..6b3e5326840ac6e45e7d6d9945be5444738efc7e 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -33,11 +33,11 @@ Currently pad token by default in [HuggingFace Tokenizer is `None`](https://gith
 ```python
 tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-model.resize_token_embeddings(model.config.vocab_size + 1) 
+model.resize_token_embeddings(model.config.vocab_size + 1)
 ```
 Padding would be required for batch inference. In this [example](../examples/inference.py), batch size = 1, so padding is essentially not required. However, we added the code pointer as an example in case of batch inference.
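+
+For illustration only, padded batch inference might look like the following minimal sketch (the checkpoint path and prompts are placeholders, not files from this repo):
+
+```python
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+model_path = "path/to/llama/7B"  # hypothetical local checkpoint
+tokenizer = LlamaTokenizer.from_pretrained(model_path)
+model = LlamaForCausalLM.from_pretrained(model_path)
+
+# add the pad token and resize the embeddings as shown above
+tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+model.resize_token_embeddings(model.config.vocab_size + 1)
+
+# left padding is generally preferred when batch-generating with decoder-only models
+tokenizer.padding_side = "left"
+
+prompts = ["Summarize: The quick brown fox jumped over the lazy dog.", "Write a haiku about the sea."]
+batch = tokenizer(prompts, return_tensors="pt", padding=True)
+outputs = model.generate(**batch, max_new_tokens=64)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```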
 
@@ -99,7 +99,7 @@ In case you have fine-tuned your model with pure FSDP and saved the checkpoints
 This is helpful if you have fine-tuned your model using FSDP only, as follows:
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16
 ```
 Then convert your FSDP checkpoint to HuggingFace checkpoints using:
 ```bash
@@ -116,6 +116,18 @@ python examples/inference.py --model_name <training_config.output_dir> --prompt_
 
 ```
 
+## Prompt Llama 2
+
+As outlined in [this Hugging Face blog post](https://huggingface.co/blog/llama2#how-to-prompt-llama-2), you can use the template below to prompt Llama 2 chat models; see the post for more details.
+
+```
+<s>[INST] <<SYS>>
+{{ system_prompt }}
+<</SYS>>
+
+{{ user_message }} [/INST]
+
+```
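+
+For example, a single-turn prompt can be assembled from this template and run roughly as in the sketch below (the model path and messages are placeholders; note that the tokenizer adds the leading `<s>` token by itself, so it is not typed into the string):
+
+```python
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+model_path = "path/to/llama/7B-chat"  # hypothetical local chat checkpoint
+tokenizer = LlamaTokenizer.from_pretrained(model_path)
+model = LlamaForCausalLM.from_pretrained(model_path)
+
+system_prompt = "You are a helpful, respectful and honest assistant."
+user_message = "What is the capital of France?"
+
+prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"
+
+inputs = tokenizer(prompt, return_tensors="pt")
+outputs = model.generate(**inputs, max_new_tokens=128)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```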
 
 ## Other Inference Options
 
diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md
index f19c4a8ca8d9115af3611a0e1c1a377b58571c37..50450e6abdead0b5b9572d7fc43fd7eb90ca845b 100644
--- a/docs/multi_gpu.md
+++ b/docs/multi_gpu.md
@@ -62,7 +62,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
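+
+As in the README, gradient accumulation can be layered on top of `--batch_size_training 1` by setting the new `gradient_accumulation_steps` training-config field, passed as a command-line override like the other flags shown here (the value of 4 below is purely illustrative):
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --gradient_accumulation_steps 4 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+```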
 
@@ -120,6 +120,7 @@ model_name: str="PATH/to/LLAMA 2/7B"
 enable_fsdp: bool= False
 run_validation: bool=True
 batch_size_training: int=4
+gradient_accumulation_steps: int=1
 num_epochs: int=3
 num_workers_dataloader: int=2
 lr: float=2e-4
@@ -129,7 +130,6 @@ use_fp16: bool=False
 mixed_precision: bool=True
 val_batch_size: int=4
 dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
-micro_batch_size: int=1
 peft_method: str = "lora" # None , llama_adapter, prefix
 use_peft: bool=False
 output_dir: str = "./ft-output"
diff --git a/docs/single_gpu.md b/docs/single_gpu.md
index 2131ee6ed86316bd472e5c86740ee437eda4ecf7..5b183094f2dcd4ba2a4c76d1a6ac3803e1e8bfba 100644
--- a/docs/single_gpu.md
+++ b/docs/single_gpu.md
@@ -76,6 +76,7 @@ model_name: str="PATH/to/LLAMA 2/7B"
 enable_fsdp: bool= False
 run_validation: bool=True
 batch_size_training: int=4
+gradient_accumulation_steps: int=1
 num_epochs: int=3
 num_workers_dataloader: int=2
 lr: float=2e-4
@@ -85,7 +86,6 @@ use_fp16: bool=False
 mixed_precision: bool=True
 val_batch_size: int=4
 dataset = "samsum_dataset" # alpaca_dataset,grammar_dataset
-micro_batch_size: int=1
 peft_method: str = "lora" # None , llama_adapter, prefix
 use_peft: bool=False
 output_dir: str = "./ft-output"
diff --git a/examples/chat_completion/chat_completion.py b/examples/chat_completion/chat_completion.py
index 1f28afa0e97f9ae1221b7aa32462ad1028ffce98..e8665cd948201ed2f35970234b613a4d0739448d 100644
--- a/examples/chat_completion/chat_completion.py
+++ b/examples/chat_completion/chat_completion.py
@@ -107,7 +107,7 @@ def main(
             tokens= tokens.unsqueeze(0)
             tokens= tokens.to("cuda:0")
             outputs = model.generate(
-                tokens,
+                input_ids=tokens,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 top_p=top_p,
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 8062fda5990e5e454e272f868b553274edcb0b03..53148773e13be45f680936a192208cf57edaa2bf 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -11,6 +11,7 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
@@ -21,7 +22,6 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    micro_batch_size: int=4
     peft_method: str = "lora" # None , llama_adapter, prefix
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 8b6db344b17f9e9d8fa7af7bddd6677bdb344cac..3043ab73ae5a1227fa5cf14913e57bbf6464973d 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -64,9 +64,6 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
-    # Calculate gradient accumulation steps
-    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-
     # Load the pre-trained model and setup its configuration
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         """
@@ -240,7 +237,7 @@ def main(**kwargs):
         tokenizer,
         optimizer,
         scheduler,
-        gradient_accumulation_steps,
+        train_config.gradient_accumulation_steps,
         train_config,
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 24366db9eba8101cf42aa3b52be9ec3a4a728682..66dd720e522c748216c24f56322f2a3daf09a1d2 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -68,7 +68,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
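+            # the progress bar below advances once per optimizer step, so its total is the number of
+            # micro-batches divided by the gradient accumulation factor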
+            total_length = len(train_dataloader)//gradient_accumulation_steps
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+            for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
@@ -84,17 +86,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
+                        pbar.update(1)  # advance the bar once per optimizer step
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                if train_config.enable_fsdp:
-                    if rank==0:       
-                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-                else:
-                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                        pbar.update(1)  # advance the bar once per optimizer step
+
+                pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+            pbar.close()
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
         # Reducing total_loss across all devices if there's more than one CUDA device