From a955ed1999381ae5e10edc3b251a6fa131eb881f Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
Date: Sun, 30 Jul 2023 18:00:08 +0000
Subject: [PATCH] Added checks for dist barrier and commented out CUDA
 expandable segments and dist_debug

---
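Notes: every dist.barrier() call is now guarded behind
train_config.enable_fsdp, because torch.distributed collectives fail
when no default process group has been initialized (e.g. single-GPU,
non-FSDP runs). A minimal sketch of the guard pattern, reusing the
train_config object and the dist alias from train_utils.py:

    import torch.distributed as dist

    # A bare dist.barrier() errors out unless
    # dist.init_process_group() has been called, which only
    # happens on the FSDP code path.
    if train_config.enable_fsdp:
        dist.barrier()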
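
The commented-out allocator flag can be tried on a PyTorch nightly by
setting it before torch initializes CUDA; a hypothetical usage sketch,
not part of this patch:

    import os

    # Expandable segments let the caching allocator grow existing
    # blocks instead of fragmenting new ones, which avoids some OOMs;
    # set the variable before the first CUDA allocation so it takes
    # effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
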
 utils/train_utils.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/utils/train_utils.py b/utils/train_utils.py
index 97b729ad..5f4fe151 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -137,7 +137,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                dist.barrier()
+                if train_config.enable_fsdp:
+                    dist.barrier()
                 if train_config.use_peft:
                     if train_config.enable_fsdp:
                         if rank==0:
@@ -173,7 +174,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         )
-                        print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                         print("=====================================================")                     
-                dist.barrier()
+                if train_config.enable_fsdp:
+                    dist.barrier()
             
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
@@ -205,7 +207,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
-    dist.barrier()
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -285,8 +286,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
-    os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag helps with CUDA memory fragmentation, which can lead to OOM in some cases.
+    # Note: this is only available in PyTorch nightlies (as of July 30, 2023).
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
-        print(f"--> Running with torch dist debug set to detail")
+        print("--> Running with debug environment flags set")
 
-- 
GitLab