diff --git a/src/tests/datasets/test_samsum_datasets.py b/src/tests/datasets/test_samsum_datasets.py index 8e3c2795e378b32980358fbd73e1df02644a901c..3a71059daa68bf89920714eb67b6f08b0c5c0e51 100644 --- a/src/tests/datasets/test_samsum_datasets.py +++ b/src/tests/datasets/test_samsum_datasets.py @@ -5,11 +5,19 @@ import pytest from dataclasses import dataclass from functools import partial from unittest.mock import patch +from datasets import load_dataset @dataclass class Config: model_type: str = "llama" +try: + load_dataset("Samsung/samsum") + SAMSUM_UNAVAILABLE = False +except ValueError: + SAMSUM_UNAVAILABLE = True + +@pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable") @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.AutoTokenizer')