Skip to content
Snippets Groups Projects
Commit e9559d26 authored by Hamid Shojanazeri's avatar Hamid Shojanazeri
Browse files

fixing the train/eval_loss calcualtion

parent 4ba4400a
No related branches found
No related tags found
No related merge requests found
...@@ -66,7 +66,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ...@@ -66,7 +66,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
scaler = ShardedGradScaler() scaler = ShardedGradScaler()
elif train_config.use_fp16 and not train_config.enable_fsdp: elif train_config.use_fp16 and not train_config.enable_fsdp:
scaler = torch.cuda.amp.GradScaler() scaler = torch.cuda.amp.GradScaler()
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])
train_prep = [] train_prep = []
train_loss = [] train_loss = []
val_prep = [] val_prep = []
...@@ -102,12 +103,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ...@@ -102,12 +103,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if train_config.enable_fsdp:
print(f"\n step {step} is completed and loss is {loss.detach().float()}") if rank==0:
print(f"\n step {step} is completed and loss is {loss.detach().float()}")
else:
print(f"\n step {step} is completed and loss is {loss.detach().float()}")
# Reducing total_loss across all devices if there's more than one CUDA device # Reducing total_loss across all devices if there's more than one CUDA device
if torch.cuda.device_count() > 1 and train_config.enable_fsdp: if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
train_epoch_loss = total_loss / data_set_len train_epoch_loss = total_loss / len(train_dataloader)
if train_config.enable_fsdp:
train_epoch_loss = train_epoch_loss/world_size
train_perplexity = torch.exp(train_epoch_loss) train_perplexity = torch.exp(train_epoch_loss)
train_prep.append(train_perplexity) train_prep.append(train_perplexity)
...@@ -127,11 +134,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ...@@ -127,11 +134,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
if train_config.save_model and eval_epoch_loss < best_val_loss: if train_config.save_model and eval_epoch_loss < best_val_loss:
dist.barrier() dist.barrier()
if train_config.use_peft: if train_config.use_peft:
print(f"we are in the saving the PEFT modules") print(f"we are in the saving the PEFT modules")
model.save_pretrained(train_config.output_dir) model.save_pretrained(train_config.output_dir)
print(f"PEFT modules are saved in {train_config.output_dir} directory") print(f"PEFT modules are saved in {train_config.output_dir} directory")
else: else:
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
...@@ -139,16 +144,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ...@@ -139,16 +144,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
model, optimizer, rank, train_config, epoch=epoch model, optimizer, rank, train_config, epoch=epoch
) )
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
print(" we are about to save the models *******") print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
print("=====================================================")
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config) model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer: if train_config.save_optimizer:
model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
print("=====================================================")
if not train_config.use_peft and train_config.save_optimizer: if not train_config.use_peft and train_config.save_optimizer:
model_checkpointing.save_optimizer_checkpoint( model_checkpointing.save_optimizer_checkpoint(
model, optimizer, rank, train_config, epoch=epoch model, optimizer, rank, train_config, epoch=epoch
) )
print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
print("=====================================================")
dist.barrier() dist.barrier()
if eval_epoch_loss < best_val_loss: if eval_epoch_loss < best_val_loss:
...@@ -192,6 +202,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): ...@@ -192,6 +202,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
Returns: eval_ppl, eval_epoch_loss Returns: eval_ppl, eval_epoch_loss
""" """
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])
model.eval() model.eval()
eval_preds = [] eval_preds = []
eval_loss = 0.0 # Initialize evaluation loss eval_loss = 0.0 # Initialize evaluation loss
...@@ -223,7 +235,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): ...@@ -223,7 +235,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
# Compute average loss and perplexity # Compute average loss and perplexity
eval_epoch_loss = eval_loss / eval_dataset_len eval_epoch_loss = eval_loss / len(eval_dataloader)
if train_config.enable_fsdp:
eval_epoch_loss = eval_epoch_loss/world_size
eval_ppl = torch.exp(eval_epoch_loss) eval_ppl = torch.exp(eval_epoch_loss)
# Print evaluation metrics # Print evaluation metrics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment