From 448af9d7c198828b51d5304e247a7393a6802197 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 14 Oct 2024 14:34:02 -0700
Subject: [PATCH] Fix test on non cuda machine

---
 src/tests/test_batching.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/tests/test_batching.py b/src/tests/test_batching.py
index 40b915cb..b51e8691 100644
--- a/src/tests/test_batching.py
+++ b/src/tests/test_batching.py
@@ -92,6 +92,7 @@ def test_packing(
 
 
 @pytest.mark.skip_missing_tokenizer
+@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -119,6 +120,7 @@ def test_distributed_packing(
     tokenizer,
     train,
     cuda_is_available,
+    cuda_is_bf16_supported,
     setup_tokenizer,
     setup_processor,
     llama_version,
@@ -133,6 +135,7 @@ def test_distributed_packing(
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
     cuda_is_available.return_value = False
+    cuda_is_bf16_supported.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
-- 
GitLab