diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f62da5bc9b1a2f5039aced1062b5a1c65b513e8e..ab85e3f425e58a98359522727de632a96b1d56b4 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -6,7 +6,6 @@ import time
 import yaml
 from contextlib import nullcontext
 from pathlib import Path
-from pkg_resources import packaging
 from datetime import datetime
 import contextlib
 
@@ -474,7 +473,7 @@ def get_policies(cfg, rank):
     verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
-    and packaging.version.parse(torch.version.cuda).release >= (11, 0)
+    and torch.version.cuda >= "11.0"
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
     ) or