From 19089269d3b49088fef854059b371d9fb575da08 Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
Date: Thu, 11 Jan 2024 06:23:06 +0000
Subject: [PATCH] add gc

---
 src/llama_recipes/finetuning.py        | 6 ++++--
 src/llama_recipes/utils/train_utils.py | 8 +++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 79f971a0..013c8f6b 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -3,8 +3,9 @@
 
 import os
 from pkg_resources import packaging
-
+import gc
 import fire
+
 import torch
 import torch.distributed as dist
 import torch.optim as optim
@@ -44,8 +45,9 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
-
+import gc
 def main(**kwargs):
+    gc.disable()
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
 
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f545ec8a..b561992f 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -7,7 +7,7 @@ import yaml
 from pathlib import Path
 from pkg_resources import packaging
 import contextlib
-
+import gc
 
 import torch
 import torch.cuda.nccl as nccl
@@ -100,6 +100,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
+                    if step > 5:
+                        break
+                    gc.collect(1)
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             batch[key] = batch[key].to(local_rank)
@@ -285,6 +288,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            if step > 5:
+                break
+            gc.collect(1)
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
-- 
GitLab