diff --git a/utils/train_utils.py b/utils/train_utils.py index d03f62aa6927834045bab98c9da4a01714c4c2ff..27cd93f0186301eae7bdf9a59a4515a72e2da9e8 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -253,6 +253,7 @@ def setup_environ_flags(rank): os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' if rank == 0: print(f"--> Running with torch dist debug set to detail")