From 19089269d3b49088fef854059b371d9fb575da08 Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
Date: Thu, 11 Jan 2024 06:23:06 +0000
Subject: [PATCH] add manual gc control to the training loop

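Turn off Python's automatic garbage collector at the start of finetuning and
instead run an explicit generation-1 collection at the top of every training
and evaluation step, so collections happen at a fixed point in the loop rather
than whenever an allocation threshold is crossed. Both loops are also cut off
after six steps.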
---
 src/llama_recipes/finetuning.py        | 4 +++-
 src/llama_recipes/utils/train_utils.py | 8 +++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 79f971a0..013c8f6b 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -3,8 +3,9 @@
 
 import os
 from pkg_resources import packaging
-
+import gc
 import fire
+
 import torch
 import torch.distributed as dist
 import torch.optim as optim
@@ -44,8 +45,9 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
 
 
 def main(**kwargs):
+    gc.disable()
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
 
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f545ec8a..b561992f 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -7,7 +7,7 @@ import yaml
 from pathlib import Path
 from pkg_resources import packaging
 import contextlib
-
+import gc
 
 import torch
 import torch.cuda.nccl as nccl
@@ -100,6 +100,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
+                    if step > 5:
+                        break
+                    gc.collect(1)
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             batch[key] = batch[key].to(local_rank)
@@ -285,6 +288,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            if step > 5:
+                break
+            gc.collect(1)
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
-- 
GitLab
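
The pattern the patch applies is the usual way to keep CPython's cyclic garbage collector from pausing a hot loop at arbitrary points: disable the automatic collector once, then trigger a small, bounded collection at a fixed place in each iteration. Below is a minimal standalone sketch of that pattern; train_step, run and num_steps are made-up placeholders, not llama-recipes code.

    import gc
    import time

    def train_step(step):
        # Stand-in for the real forward/backward/optimizer work of one batch.
        time.sleep(0.01)

    def run(num_steps=100):
        gc.disable()  # stop threshold-driven automatic collections
        try:
            for step in range(num_steps):
                # Collect generations 0 and 1 at a predictable point in the loop;
                # a full gc.collect() every step would be considerably more expensive.
                gc.collect(1)
                train_step(step)
        finally:
            gc.enable()  # restore normal collector behaviour afterwards

    if __name__ == "__main__":
        run()

Collecting only the younger generations keeps the per-step cost low while still reclaiming short-lived cycles. Note that the patch itself never calls gc.enable() again after finetuning; whether to restore the automatic collector once training ends is a separate choice.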