diff --git a/utils/train_utils.py b/utils/train_utils.py
index d03f62aa6927834045bab98c9da4a01714c4c2ff..27cd93f0186301eae7bdf9a59a4515a72e2da9e8 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -253,6 +253,7 @@ def setup_environ_flags(rank):
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
     os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")