Commit 707af7ea authored by Hamid Shojanazeri

adding cuda:0 for non-fsdp situations

parent 1e0f8a1f
@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         batch[key] = batch[key].to(local_rank)
     else:
-        batch[key] = batch[key].to('cuda')
+        batch[key] = batch[key].to('cuda:0')
 outputs = model(**batch)
 loss = outputs.loss
 loss = loss / gradient_accumulation_steps
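
For reference, the branch touched by this hunk can be sketched as a pair of small helpers. This is only an illustration under stated assumptions: `resolve_device` and `move_batch` are hypothetical names that do not appear in the commit, and the batch values are assumed to expose a `.to(device)` method, as PyTorch tensors do. In PyTorch, the string 'cuda' refers to the current CUDA device, while 'cuda:0' names device index 0 explicitly, which appears to be the intent of the change.

```python
def resolve_device(enable_fsdp, local_rank):
    """Pick the target passed to .to(), mirroring the if/else in the hunk.

    Hypothetical helper, not part of the commit.
    """
    if enable_fsdp:
        # Under FSDP each process moves its batch to its own local rank.
        return local_rank
    # Outside FSDP the commit pins the batch to the first GPU explicitly.
    return "cuda:0"


def move_batch(batch, device):
    """Move every value in a dict-like batch to the given device.

    Assumes each value has a .to(device) method (e.g. a torch.Tensor).
    """
    for key in batch.keys():
        batch[key] = batch[key].to(device)
    return batch


# Possible usage inside a training loop (names taken from the diff context):
#   device = resolve_device(train_config.enable_fsdp, local_rank)
#   batch = move_batch(batch, device)
#   outputs = model(**batch)
```

Resolving the device once, before the loop over keys, keeps the FSDP check out of the per-key path; whether that fits the surrounding training function is a judgment for the author.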