From 6217635e8779d7fc68f58fe7ecf7118b5a76d454 Mon Sep 17 00:00:00 2001
From: "hongbo.mo" <hongbo.mo@upai.com>
Date: Fri, 15 Sep 2023 11:16:23 +0800
Subject: [PATCH] Fix tqdm bar not changing length after the terminal is resized

---
 src/llama_recipes/datasets/utils.py    | 4 ++--
 src/llama_recipes/utils/train_utils.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/llama_recipes/datasets/utils.py b/src/llama_recipes/datasets/utils.py
index 4c6956d8..0a11d8c3 100644
--- a/src/llama_recipes/datasets/utils.py
+++ b/src/llama_recipes/datasets/utils.py
@@ -52,7 +52,7 @@ class ConcatDataset(Dataset):
             "labels": [],
             }
 
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset"):
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
 
             while len(next(iter(buffer.values()))) > self.chunk_size:
@@ -63,4 +63,4 @@ class ConcatDataset(Dataset):
         return self.samples[idx]
 
     def __len__(self):
-        return len(self.samples)
\ No newline at end of file
+        return len(self.samples)
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 2f8faaee..c7aa6ff6 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -69,7 +69,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length)
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
            for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -227,7 +227,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
+        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
-- 
GitLab
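
Note: as a quick illustration of the flag this patch adds, here is a minimal standalone sketch (not code from this repository). tqdm's dynamic_ncols=True makes the bar re-measure the terminal width on every refresh, so it follows window resizes instead of keeping the width it had when the bar was created.

    import time
    from tqdm import tqdm

    # Hypothetical loop purely to demonstrate the flag; the actual patch applies it
    # to the dataset preprocessing, training, and evaluation loops in llama_recipes.
    for _ in tqdm(range(200), desc="Preprocessing dataset", dynamic_ncols=True):
        time.sleep(0.01)  # stand-in for per-sample work; resize the terminal to watch the bar adapt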