diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 79f971a0a9fb35a4152c3c96d5eb941bc0a13df6..013c8f6be1aae46ea133b671292f337ccbedecb9 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -3,8 +3,9 @@
 
 import os
 from pkg_resources import packaging
-
+import gc
 import fire
+
 import torch
 import torch.distributed as dist
 import torch.optim as optim
@@ -44,8 +45,9 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
-
+import gc
 def main(**kwargs):
+    gc.disable()  # turn off automatic garbage collection for the whole run
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f545ec8a13a0672126aa38594ed62ce97a07e859..b561992f116e0e21df2411c62dc4184f889bce9a 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -7,7 +7,7 @@ import yaml
 from pathlib import Path
 from pkg_resources import packaging
 import contextlib
-
+import gc
 
 import torch
 import torch.cuda.nccl as nccl
@@ -100,6 +100,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
+                    if step > 5:
+                        break  # stop after the first six steps
+                    gc.collect(1)  # manual pass over generations 0 and 1 only
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             batch[key] = batch[key].to(local_rank)
@@ -285,6 +288,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            if step > 5:
+                break  # evaluate only the first six batches
+            gc.collect(1)  # manual pass over generations 0 and 1 only
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
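
Taken together, the patch applies one pattern in two places: disable Python's automatic garbage collector once at startup (gc.disable() in main), then trigger a small collection at a fixed point in each loop iteration (gc.collect(1)), so GC pauses land at predictable moments instead of firing mid-step. Below is a minimal self-contained sketch of that pattern; do_step is a hypothetical stand-in for one training step, not a function from llama-recipes.

import gc

def do_step(i):
    # Hypothetical stand-in for one training step; churns out
    # short-lived objects the way a real step would.
    return [str(i)] * 1000

def run(num_steps=6):
    gc.disable()            # no automatic collections from here on
    try:
        for step in range(num_steps):
            do_step(step)
            # Mirror the patch's gc.collect(1): sweep generations 0
            # and 1 only, skipping the expensive full gen-2 pass.
            gc.collect(1)
    finally:
        gc.enable()         # restore the default collector

if __name__ == "__main__":
    run()

Note that gc.disable() only turns off the cyclic collector; CPython's reference counting still frees most objects immediately, which is what keeps the per-step gc.collect(1) cheap.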