diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index 7fef0f5bdf1a1e04c286e59b543d5423e93869b2..1476bf3c1e2d4fb2c492b32a88af9a4f7ddab0c0 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -3,15 +3,13 @@
 
 import pytest
 
-from transformers import AutoTokenizer
+from utils import maybe_tokenizer
 
-ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+ACCESS_ERROR_MSG = "Could not access tokenizer. Did you log into the Hugging Face Hub and provide the correct token?"
 
-try:
-    AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-    LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
-except OSError:
-    LLAMA_VERSIONS = ["fake_llama"]
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct", "fake_llama"]
+
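+# Map each version to its tokenizer; maybe_tokenizer returns None when the checkpoint cannot be loaded (see utils.py).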
+LLAMA_TOKENIZERS = {k: maybe_tokenizer(k) for k in LLAMA_VERSIONS}
 
 @pytest.fixture(params=LLAMA_VERSIONS)
 def llama_version(request):
@@ -22,35 +20,10 @@ def llama_version(request):
 def model_type(request):
     return request.param
 
-class FakeTokenier(object):
-    def __init__(self):
-        self.pad_token_id = 0
-        self.bos_token_id = 1
-        self.eos_token_id = 2
-        self.sep_token_id = 3
-
-        self.pad_token = "<|pad_id|>"
-        self.bos_token = "<|bos_id|>"
-        self.eos_token = "<|eos_id|>"
-        self.sep_token = "<|sep_id|>"
-
-    def __call__(self, *args, **kwargs):
-        return self.encode(*args, **kwargs)
-
-    def encode(self, text, *args, **kwargs):
-        breakpoint()
-        return [len(c) for c in text.split(" ")]
-    
-    def __len__(self):
-        return 128256
-
 
 @pytest.fixture(scope="module")
 def llama_tokenizer(request):
-    if LLAMA_VERSIONS == ["fake_llama"]:
-        return {"fake_llama": FakeTokenier()}
-    else:
-        return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
+    return LLAMA_TOKENIZERS
 
 
 @pytest.fixture
@@ -61,6 +34,13 @@ def setup_tokenizer(llama_tokenizer, llama_version):
 
     return _helper
 
+@pytest.fixture
+def setup_processor(llama_tokenizer, llama_version):
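+    """Attach the tokenizer for the current llama_version to the mocked AutoProcessor."""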
+    def _helper(processor_mock):
+        processor_mock.from_pretrained.return_value.tokenizer = llama_tokenizer[llama_version]
+
+    return _helper
+
 
 def pytest_addoption(parser):
     parser.addoption(
@@ -73,16 +53,18 @@ def pytest_configure(config):
 
 
 def pytest_collection_modifyitems(config, items):
+    # Skip tests marked with skip_missing_tokenizer when the required tokenizer is unavailable, unless --unskip-missing-tokenizer is passed.
     if config.getoption("--unskip-missing-tokenizer"):
         return
 
-    try:
-        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-        tokenizer_available = True
-    except OSError:
-        tokenizer_available = False
-
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
-        if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
+        # determine which llama version (if any) this test is parametrized with
+        version = [v for v in LLAMA_VERSIONS for i in item.keywords if v in i]
+        if len(version) == 0:
+            # no tokenizer used in this test
+            continue
+        version = version.pop()
+        assert version in LLAMA_TOKENIZERS
+        if "skip_missing_tokenizer" in item.keywords and LLAMA_TOKENIZERS[version] is None:
             item.add_marker(skip_missing_tokenizer)
diff --git a/src/tests/datasets/test_custom_dataset.py b/src/tests/datasets/test_custom_dataset.py
index 4fdccfafbccc7446accdfd9a3b93ccdde7743577..f842733b7d038de11e0ddcf130ed4c41d05be946 100644
--- a/src/tests/datasets/test_custom_dataset.py
+++ b/src/tests/datasets/test_custom_dataset.py
@@ -2,6 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
+from contextlib import nullcontext
 from unittest.mock import patch
 
 from transformers import LlamaTokenizer
@@ -133,13 +134,16 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
         {"role":"assistant", "content":"Romans"},
     ]
 
-    result = tokenize_dialog(dialog, tokenizer)
+    c = pytest.raises(AttributeError) if llama_version == "fake_llama" else nullcontext()
+
+    with c:
+        result = tokenize_dialog(dialog, tokenizer)
     
     if "Llama-2" in llama_version:
         assert result["labels"][:12] == [-100] * 12
         assert result["labels"][17:28] == [-100] * 11
         assert result["labels"].count(-100) == 11 + 12
-    else:
+    elif "Llama-3" in llama_version:
         assert result["labels"][:38] == [-100] * 38
         assert result["labels"][43:54] == [-100] * 11
         assert result["labels"].count(-100) == 38 + 11
diff --git a/src/tests/datasets/test_samsum_datasets.py b/src/tests/datasets/test_samsum_datasets.py
index 9fd5bcafb7cf2696ebd918d4c8119619d4291485..8e3c2795e378b32980358fbd73e1df02644a901c 100644
--- a/src/tests/datasets/test_samsum_datasets.py
+++ b/src/tests/datasets/test_samsum_datasets.py
@@ -2,20 +2,42 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
+from dataclasses import dataclass
 from functools import partial
 from unittest.mock import patch
 
+@dataclass
+class Config:
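+    """Minimal stand-in for the config object returned by AutoConfig.from_pretrained."""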
+    model_type: str = "llama"
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
+def test_samsum_dataset(
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    mocker,
+    setup_tokenizer,
+    llama_version,
+    ):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
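+    # Stub the multimodal (Mllama) model and return a plain llama text config so only the text-only path is exercised.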
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config()
 
     BATCH_SIZE = 8
     kwargs = {
diff --git a/src/tests/test_batching.py b/src/tests/test_batching.py
index 3aa20fdecde95c0c54f4fa38e4590b1349b4923f..81f7d25bd6857e29b885d7b119a0a95402c38009 100644
--- a/src/tests/test_batching.py
+++ b/src/tests/test_batching.py
@@ -3,6 +3,7 @@
 
 import pytest
 from dataclasses import dataclass
+from contextlib import nullcontext
 from unittest.mock import patch
 
 @dataclass
@@ -19,14 +20,16 @@ EXPECTED_SAMPLE_NUMBER ={
         "eval": 34,
     },
     "fake_llama": {
-        "train": 48,
-        "eval": 34,
+        "train": 50,
+        "eval": 21,
     }
 }
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -34,17 +37,22 @@ def test_packing(
     step_lr,
     optimizer,
     get_model,
+    get_mmodel,
+    processor,
     get_config,
     tokenizer,
     train,
     setup_tokenizer,
+    setup_processor,
     llama_version,
     model_type,
     ):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
     
     kwargs = {
@@ -56,35 +64,38 @@ def test_packing(
         "batching_strategy": "packing",
         }
 
-    main(**kwargs)
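+    # Packing is only expected to succeed for the text-only "llama" model type; otherwise main should raise ValueError.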
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
 
-    assert train.call_count == 1
+    with c:
+        main(**kwargs)
+
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
-    # assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
-    # print(f"{len(eval_dataloader)=}")
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
-    # batch = next(iter(train_dataloader))
+        batch = next(iter(train_dataloader))
 
-    # assert "labels" in batch.keys()
-    # assert "input_ids" in batch.keys()
-    # assert "attention_mask" in batch.keys()
+        assert "labels" in batch.keys()
+        assert "input_ids" in batch.keys()
+        assert "attention_mask" in batch.keys()
 
-    # # assert batch["labels"][0].size(0) == 4096
-    # # assert batch["input_ids"][0].size(0) == 4096
-    # # assert batch["attention_mask"][0].size(0) == 4096
-    # print(batch["labels"][0].size(0))
-    # print(batch["input_ids"][0].size(0))
-    # print(batch["attention_mask"][0].size(0))
-    
+        assert batch["labels"][0].size(0) == 4096
+        assert batch["input_ids"][0].size(0) == 4096
+        assert batch["attention_mask"][0].size(0) == 4096
 
 
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -92,12 +103,34 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_distributed_packing(
+    dist,
+    is_initialized,
+    fsdp,
+    setup,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    cuda_is_available,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+    ):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
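+    # Pretend no GPU is available so the test stays on the CPU-only code path.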
+    cuda_is_available.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
@@ -120,13 +153,17 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     dist.get_rank.return_value = rank
     dist.get_world_size.return_value = 2
 
-    main(**kwargs)
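+    # As in test_packing, non-llama model types are expected to fail with ValueError when packing is requested.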
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
-    assert train.call_count == 1
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
diff --git a/src/tests/test_chat_completion.py b/src/tests/test_chat_completion.py
index b49054f635717fa0410e1eaf1e3d6b5e7b489353..266252317819bded34f76fdcbcea75c192327e14 100644
--- a/src/tests/test_chat_completion.py
+++ b/src/tests/test_chat_completion.py
@@ -52,8 +52,8 @@ def _format_tokens_llama3(dialogs, tokenizer):
 def test_chat_completion(
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
 ):
-    if "Llama-2" in llama_version:
-        pytest.skip("skipping test for Llama-2")
+    if "Llama-2" in llama_version or llama_version == "fake_llama":
+        pytest.skip(f"skipping test for {llama_version}")
 
     from chat_completion import main
 
diff --git a/src/tests/utils.py b/src/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dbf30746a66061cc9d781c1ee38e4136bf75c74
--- /dev/null
+++ b/src/tests/utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from transformers import AutoTokenizer
+
+
+class FakeTokenizer(object):
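+    """Minimal stand-in for a Hugging Face tokenizer, used when the real checkpoints are not accessible."""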
+    def __init__(self):
+        self.pad_token_id = 0
+        self.bos_token_id = 42
+        self.eos_token_id = 43
+        self.sep_token_id = 3
+        self.vocab_size = 128256
+
+        self.pad_token = "<|pad_id|>"
+        self.bos_token = "<|bos_id|>"
+        self.eos_token = "<|eos_id|>"
+        self.sep_token = "<|sep_id|>"
+        self.tokenizer = self
+        self.padding_side = "left"
+
+    def __call__(self, *args, **kwargs):
+        ids = self.encode(*args, **kwargs)
+        return {"input_ids": ids}
+
+    def encode(self, text, *args, **kwargs):
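+        # Deterministic fake encoding: one id per whitespace-separated chunk (its character length), wrapped in bos/eos ids.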
+        return [self.bos_token_id] + [len(c) for c in text.split(" ")] + [self.eos_token_id]
+
+    def __len__(self):
+        return 128256
+
+    def pad(self, *args, **kwargs):
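+        # Pad every feature of each example to the longest input_ids length in the batch (pad_token_id for input_ids, 0 otherwise).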
+        args = args[0]
+        max_len = max([len(a["input_ids"]) for a in args])
+        for a in args:
+            for k in a.keys():
+                a[k] = a[k] + ([self.pad_token_id if k == "input_ids" else 0] * (max_len - len(a[k])))
+        out = {}
+        for k in args[0].keys():
+            out[k] = [a[k] for a in args]
+        return out
+
+
+def maybe_tokenizer(name):
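+    # Return the FakeTokenizer stub for "fake_llama", a real tokenizer if the checkpoint can be loaded, else None.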
+    if name == "fake_llama":
+        return FakeTokenizer()
+    try:
+        return AutoTokenizer.from_pretrained(name)
+    except OSError:
+        return None