diff --git a/utils/train_utils.py b/utils/train_utils.py index 7366e3f3e33380eb31844df2176f56e594a06ee6..8a68f0c18b7e151d7c2b3b13de5627a87a8aa901 100644 --- a/utils/train_utils.py +++ b/utils/train_utils.py @@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) else: - batch[key] = batch[key].to('cuda') + batch[key] = batch[key].to('cuda:0') outputs = model(**batch) loss = outputs.loss loss = loss / gradient_accumulation_steps