Commit fce8c260 authored by Matthias Reso

[WIP]add fake tokenizer

parent fac71cd1
@@ -6,7 +6,12 @@ import pytest
from transformers import AutoTokenizer
ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provide the correct token?"

# Fall back to a fake tokenizer if the gated Llama tokenizer is not accessible.
try:
    AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
except OSError:
    LLAMA_VERSIONS = ["fake_llama"]
@pytest.fixture(params=LLAMA_VERSIONS)
def llama_version(request):
@@ -17,10 +22,35 @@ def llama_version(request):
def model_type(request):
    return request.param
class FakeTokenizer(object):
    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.sep_token_id = 3

        self.pad_token = "<|pad_id|>"
        self.bos_token = "<|bos_id|>"
        self.eos_token = "<|eos_id|>"
        self.sep_token = "<|sep_id|>"

    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)

    def encode(self, text, *args, **kwargs):
        # Fake tokenization: one pseudo token id per whitespace-separated word.
        return [len(c) for c in text.split(" ")]

    def __len__(self):
        # Mimic the Llama 3.1 vocabulary size.
        return 128256
@pytest.fixture(scope="module")
def llama_tokenizer(request):
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
if LLAMA_VERSIONS == ["fake_llama"]:
return {"fake_llama": FakeTokenier()}
else:
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
@pytest.fixture
......
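Aside (not part of the commit): a minimal sketch of how a test could consume the llama_tokenizer fixture defined above. The test name and assertions are illustrative and rely only on attributes that both AutoTokenizer and the FakeTokenizer expose:

def test_tokenizer_smoke(llama_tokenizer, llama_version):
    # Illustrative only; not part of this commit.
    tokenizer = llama_tokenizer[llama_version]
    assert tokenizer.bos_token_id != tokenizer.eos_token_id
    ids = tokenizer.encode("a packed sample")  # FakeTokenizer would return [1, 6, 6]
    assert isinstance(ids, list) and len(ids) > 0
    assert len(tokenizer) > 0  # vocabulary size; 128256 for the FakeTokenizer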
@@ -2,8 +2,13 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from dataclasses import dataclass
from unittest.mock import patch
@dataclass
class Config:
    model_type: str = "llama"
EXPECTED_SAMPLE_NUMBER = {
    "meta-llama/Llama-2-7b-hf": {
        "train": 96,
@@ -12,20 +17,35 @@ EXPECTED_SAMPLE_NUMBER = {
"meta-llama/Meta-Llama-3.1-8B-Instruct": {
"train": 79,
"eval": 34,
},
"fake_llama": {
"train": 48,
"eval": 34,
}
}
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(
    step_lr,
    optimizer,
    get_model,
    get_config,
    tokenizer,
    train,
    setup_tokenizer,
    llama_version,
    model_type,
):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_config.return_value = Config(model_type=model_type)

    kwargs = {
        "model_name": llama_version,
@@ -45,20 +65,24 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenize
    eval_dataloader = args[2]

    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]

    batch = next(iter(train_dataloader))

    assert "labels" in batch.keys()
    assert "input_ids" in batch.keys()
    assert "attention_mask" in batch.keys()

    assert batch["labels"][0].size(0) == 4096
    assert batch["input_ids"][0].size(0) == 4096
    assert batch["attention_mask"][0].size(0) == 4096
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
......
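Aside (not part of the commit): the skip_missing_tokenizer marker used above is typically honored through a collection hook in conftest.py. The wiring below is a hedged sketch, not taken from this repository; with the fake tokenizer in place the project may instead choose to let these tests run, so only the skip path is shown:

import pytest

def pytest_collection_modifyitems(config, items):
    # Hypothetical: skip tests that need the real, gated tokenizer when
    # only the fake_llama fallback is available.
    if LLAMA_VERSIONS == ["fake_llama"]:
        skip_marker = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
        for item in items:
            if "skip_missing_tokenizer" in item.keywords:
                item.add_marker(skip_marker)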