From ec00a2f72276f7a0e3c341c55f72c73d4f74ddba Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Fri, 8 Sep 2023 16:57:43 -0700
Subject: [PATCH] Fix batching error

---
 examples/custom_dataset.py            | 2 +-
 tests/datasets/test_custom_dataset.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py
index dc9c20c9..8a5f1da0 100644
--- a/examples/custom_dataset.py
+++ b/examples/custom_dataset.py
@@ -27,7 +27,7 @@ def tokenize_dialog(dialog, tokenizer):
 
     combined_tokens = {}
     for k in dialog_tokens[0].keys():
-        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
     return combined_tokens

diff --git a/tests/datasets/test_custom_dataset.py b/tests/datasets/test_custom_dataset.py
index 151a2822..c519de30 100644
--- a/tests/datasets/test_custom_dataset.py
+++ b/tests/datasets/test_custom_dataset.py
@@ -17,7 +17,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
-        "batch_size_training": 1,
+        "batch_size_training": 2,
         "use_peft": False,
     }
@@ -30,9 +30,9 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 2241
-    assert len(eval_dataloader) == 2241
-
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
     STRING = tokenizer.decode(next(iter(train_dataloader))["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
--
GitLab
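
A minimal, self-contained sketch (not part of the patch; the toy per-turn token dicts below are invented for illustration) of what the flattening in tokenize_dialog produces after this fix, and why the old extra list nesting broke batching when batch_size_training > 1:

# Sketch only: a stand-alone reproduction of the fixed flattening logic.
import itertools

def combine_dialog_tokens(dialog_tokens):
    combined_tokens = {}
    for k in dialog_tokens[0].keys():
        # Fixed behavior: one flat token list per key, instead of a list wrapped
        # inside another list as the old code produced.
        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
    return combined_tokens

# Two already-tokenized "turns" of a dialog (values are made-up token ids).
turns = [
    {"input_ids": [1, 10, 11], "attention_mask": [1, 1, 1]},
    {"input_ids": [12, 13, 2], "attention_mask": [1, 1, 1]},
]

print(combine_dialog_tokens(turns))
# {'input_ids': [1, 10, 11, 12, 13, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}
# The pre-fix code returned {'input_ids': [[1, 10, 11, 12, 13, 2]], ...}: the extra
# nesting level adds a spurious per-sample batch dimension, so collating several
# samples into one batch (batch_size_training > 1, as the updated test uses) fails.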