diff --git a/llama_finetuning.py b/llama_finetuning.py
index 4e324140781dabc6fd31a8915fe6293951346b9f..ccf8c68457d13597cfe96768b6f8808b94fe243b 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -134,7 +134,8 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
-            limit_all_gathers=False,
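+            # rate-limit FSDP all-gathers to reduce peak GPU memory during prefetching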
+            limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py
index e917c7f2daf234d0adcafea55f4ac284428afe9f..51193e80c063adb7812edd15acda586740fa2efc 100644
--- a/model_checkpointing/checkpoint_handler.py
+++ b/model_checkpointing/checkpoint_handler.py
@@ -212,7 +212,7 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
 
 def load_optimizer_checkpoint(model, optimizer, rank, cfg):
-    """load an fdsp optimizer full_state checkpoint using scatter method
+    """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
diff --git a/policies/activation_checkpointing_functions.py b/policies/activation_checkpointing_functions.py
index 0a1e31f427d1bedc6e7b3eb905e6614f2441be87..379bc6bfabc03f34745d995c3257181aa8ae030f 100644
--- a/policies/activation_checkpointing_functions.py
+++ b/policies/activation_checkpointing_functions.py
@@ -26,7 +26,7 @@ def apply_fsdp_checkpointing(model):
     """apply activation checkpointing to model
     returns None as model is updated directly
     """
-    print(f"--> applying fdsp activation checkpointing...")
+    print("--> applying fsdp activation checkpointing...")
 
     apply_activation_checkpointing(
         model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
diff --git a/utils/train_utils.py b/utils/train_utils.py
index e41f503ea1ba7140f84b202a4ecd72c88ee0285b..7421907585ba912e4afff3bf59cc42f45ff0ff06 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -84,9 +84,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')       
-                outputs = model(**batch)
-                loss = outputs.loss
+                        batch[key] = batch[key].to('cuda:0')
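+                # the model returns .loss directly when labels are included in the batch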
+                loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 first_key = next(iter(batch))
@@ -105,7 +105,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.step()
                         optimizer.zero_grad()
                         
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")        
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)