From ec00a2f72276f7a0e3c341c55f72c73d4f74ddba Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Fri, 8 Sep 2023 16:57:43 -0700
Subject: [PATCH] Fix batching error

---
 examples/custom_dataset.py            | 2 +-
 tests/datasets/test_custom_dataset.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py
index dc9c20c9..8a5f1da0 100644
--- a/examples/custom_dataset.py
+++ b/examples/custom_dataset.py
@@ -27,7 +27,7 @@ def tokenize_dialog(dialog, tokenizer):
     
     combined_tokens = {}  
     for k in dialog_tokens[0].keys():
-        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
     return combined_tokens
 
 
diff --git a/tests/datasets/test_custom_dataset.py b/tests/datasets/test_custom_dataset.py
index 151a2822..c519de30 100644
--- a/tests/datasets/test_custom_dataset.py
+++ b/tests/datasets/test_custom_dataset.py
@@ -17,7 +17,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
-        "batch_size_training": 1,
+        "batch_size_training": 2,
         "use_peft": False,
         }
     
@@ -30,9 +30,9 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
     
-    assert len(train_dataloader) == 2241
-    assert len(eval_dataloader) == 2241
-    
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
     STRING = tokenizer.decode(next(iter(train_dataloader))["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
     
-- 
GitLab