diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py
index dc9c20c95474444d4d2154283bd328eddce5c4d0..8a5f1da006f30ba5769d61bf710c040b24b6193c 100644
--- a/examples/custom_dataset.py
+++ b/examples/custom_dataset.py
@@ -27,7 +27,7 @@ def tokenize_dialog(dialog, tokenizer):
 
     combined_tokens = {}
     for k in dialog_tokens[0].keys():
-        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
 
     return combined_tokens
 
diff --git a/tests/datasets/test_custom_dataset.py b/tests/datasets/test_custom_dataset.py
index 151a2822e979862f4841f53eb9fc83eea70ef7a1..c519de30f914cb3b954508385b17a1212179dd4d 100644
--- a/tests/datasets/test_custom_dataset.py
+++ b/tests/datasets/test_custom_dataset.py
@@ -17,7 +17,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
-        "batch_size_training": 1,
+        "batch_size_training": 2,
         "use_peft": False,
     }
 
@@ -30,9 +30,9 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 2241
-    assert len(eval_dataloader) == 2241
-
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
     STRING = tokenizer.decode(next(iter(train_dataloader))["input_ids"][0], skip_special_tokens=True)
 
     EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
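
For context on the first hunk: the old code wrapped the chained result in an extra pair of brackets, so every key in `combined_tokens` held a singly nested list (`[[tok, tok, ...]]`) rather than the flat token list downstream collation expects. Below is a minimal sketch of the behavior difference; the `dialog_tokens` values are made-up placeholders, not real output of `tokenize_dialog`:

```python
import itertools

# Hypothetical per-turn tokenizer outputs; the real entries come from
# tokenize_dialog in examples/custom_dataset.py.
dialog_tokens = [
    {"input_ids": [1, 2], "attention_mask": [1, 1]},
    {"input_ids": [3, 4], "attention_mask": [1, 1]},
]

combined = {}
for k in dialog_tokens[0].keys():
    # After the fix: one flat token list per key.
    combined[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))

assert combined["input_ids"] == [1, 2, 3, 4]

# Before the fix, the extra brackets produced a spurious nesting level:
nested = {k: [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
          for k in dialog_tokens[0]}
assert nested["input_ids"] == [[1, 2, 3, 4]]
```

The test hunk updates the expected dataloader lengths to match the corrected, flat samples: 226 train batches at the new `batch_size_training` of 2, and `2*226` batches for the eval loader.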