Commit 19089269 authored by Hamid Shojanazeri

add gc

parent ef810bbe
@@ -3,8 +3,9 @@
import os
from pkg_resources import packaging
import gc
import fire
import torch
import torch.distributed as dist
import torch.optim as optim
@@ -44,8 +45,9 @@ from llama_recipes.utils.train_utils import (
get_policies
)
import gc
def main(**kwargs):
gc.disable()
# Update the configuration for the training and sharding process
update_config((train_config, fsdp_config), **kwargs)
@@ -7,7 +7,7 @@ import yaml
from pathlib import Path
from pkg_resources import packaging
import contextlib
import gc
import torch
import torch.cuda.nccl as nccl
@@ -100,6 +100,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
with maybe_run_profiler(train_config) as torch_profiler:
for step, batch in enumerate(train_dataloader):
if step > 5:
break
gc.collect(1)
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
@@ -285,6 +288,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
eval_loss = 0.0 # Initialize evaluation loss
with MemoryTrace() as memtrace:
for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
if step > 5:
break
gc.collect(1)
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
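Taken together, the diff turns off automatic garbage collection at the start of main() with gc.disable() and instead triggers a manual gc.collect(1) at the top of every training and evaluation step, so collection work happens at a predictable point in the loop rather than whenever CPython's allocation thresholds fire; the step > 5 break appears to be a temporary cutoff for short profiling runs. Below is a minimal standalone sketch of that gc pattern, using a hypothetical toy model and loop rather than the llama-recipes trainer.

import gc

import torch
import torch.nn.functional as F


def train_steps(num_steps: int = 6) -> None:
    # Disable automatic collection for the whole run; explicit gc.collect()
    # calls still work while automatic collection is off.
    gc.disable()

    # Hypothetical toy model and optimizer, standing in for the real trainer.
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(num_steps):
        # Collect generations 0-1 at a known point in the step, mirroring
        # the gc.collect(1) call added in the diff above.
        gc.collect(1)

        batch = torch.randn(8, 16)
        target = torch.randn(8, 4)

        optimizer.zero_grad()
        loss = F.mse_loss(model(batch), target)
        loss.backward()
        optimizer.step()

    # Restore default collection behaviour once the loop is done.
    gc.enable()


if __name__ == "__main__":
    train_steps()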