diff --git a/generate_samples.py b/generate_samples.py
index c6876a531c518e80ef93aa096d91984e18bd0ccb..b357164612634f6bced09335591d5424576e6410 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -56,7 +56,7 @@ def setup_model(args):
             iteration, release, success = get_checkpoint_iteration(args)
             path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt")
             print('current device:', torch.cuda.current_device())
-            checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))
+            checkpoint = torch.load(path, map_location=torch.device('cpu'))
             model.load_state_dict(checkpoint["module"])
             print(f"Load model file {path}")
         else: