diff --git a/utils/train_utils.py b/utils/train_utils.py
index ce16b8033b478932edf85997a8b6576e43430c4c..7421907585ba912e4afff3bf59cc42f45ff0ff06 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -84,7 +84,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')
+
+                        batch[key] = batch[key].to('cuda:0')
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
@@ -198,7 +199,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             if train_config.enable_fsdp:
                 batch[key] = batch[key].to(local_rank)
             else:
-                batch[key] = batch[key].to('cuda')
+                batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
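
For reference, a minimal sketch of the device-placement pattern this diff touches: under FSDP each rank moves the batch to its own local_rank, otherwise tensors are pinned to an explicit device index ('cuda:0') rather than the generic 'cuda' alias. The helper name move_batch_to_device below is hypothetical and used only for illustration; it is not part of train_utils.py.

    import torch

    def move_batch_to_device(batch, enable_fsdp, local_rank):
        # Hypothetical helper mirroring the branch in train()/evaluation():
        # FSDP ranks move tensors to their own local_rank; otherwise use an
        # explicit device index instead of the generic 'cuda' alias.
        device = local_rank if enable_fsdp else "cuda:0"
        for key in batch.keys():
            batch[key] = batch[key].to(device)
        return batch

    # Usage sketch (assumes a CUDA device is available):
    if torch.cuda.is_available():
        batch = {
            "input_ids": torch.randint(0, 100, (2, 8)),
            "attention_mask": torch.ones(2, 8, dtype=torch.long),
        }
        batch = move_batch_to_device(batch, enable_fsdp=False, local_rank=0)
        print(batch["input_ids"].device)  # cuda:0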