From 448af9d7c198828b51d5304e247a7393a6802197 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 14 Oct 2024 14:34:02 -0700
Subject: [PATCH] Fix test on non cuda machine

---
 src/tests/test_batching.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/tests/test_batching.py b/src/tests/test_batching.py
index 40b915cb..b51e8691 100644
--- a/src/tests/test_batching.py
+++ b/src/tests/test_batching.py
@@ -92,6 +92,7 @@ def test_packing(
 
 
 @pytest.mark.skip_missing_tokenizer
+@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -119,6 +120,7 @@ def test_distributed_packing(
     tokenizer,
     train,
     cuda_is_available,
+    cuda_is_bf16_supported,
     setup_tokenizer,
     setup_processor,
     llama_version,
@@ -133,6 +135,7 @@ def test_distributed_packing(
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
     cuda_is_available.return_value = False
+    cuda_is_bf16_supported.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
-- 
GitLab
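
Why this fix works, as a hedged sketch: the test already mocks torch.cuda.is_available where llama_recipes.finetuning looks it up, but train_utils also queries torch.cuda.is_bf16_supported(), and on a machine without CUDA that call can raise (depending on the torch version) or return a value the test does not control. The snippet below reproduces the pattern in isolation; pick_dtype is a hypothetical stand-in for the check in train_utils, not the actual llama-recipes code.

import torch
from unittest.mock import patch

def pick_dtype():
    # Hypothetical stand-in for the kind of gate train_utils applies.
    # It queries bf16 support unconditionally; on a CUDA-less machine
    # the real call can raise, which is why the test mocks it out.
    bf16_ok = torch.cuda.is_bf16_supported()
    if torch.cuda.is_available() and bf16_ok:
        return torch.bfloat16
    return torch.float32

# Stacked @patch decorators apply bottom-up: the bottom decorator feeds
# the first mock argument and the new, topmost decorator feeds the last.
# That is why the diff appends `cuda_is_bf16_supported` after
# `cuda_is_available` in the test signature.
@patch("torch.cuda.is_bf16_supported")
@patch("torch.cuda.is_available")
def test_pick_dtype_without_gpu(cuda_is_available, cuda_is_bf16_supported):
    cuda_is_available.return_value = False
    cuda_is_bf16_supported.return_value = False
    assert pick_dtype() is torch.float32

The patch mirrors this exactly: one new @patch decorator at the top of the stack, one new parameter appended after cuda_is_available, and a return_value of False so the code under test takes the CPU/fp32 path regardless of the host machine.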