diff --git a/utils/train_utils.py b/utils/train_utils.py
index c6745d5ccbc1dc3e9c8080f74836a7cccebece2d..defb6518da833a97bddd7fe2f6261c55f349e9a0 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -85,7 +85,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
-                    elif not train_config.quantization:
+                    else:
                         batch[key] = batch[key].to('cuda')       
                 outputs = model(**batch)
                 loss = outputs.loss