diff --git a/generate_samples.py b/generate_samples.py index c6876a531c518e80ef93aa096d91984e18bd0ccb..b357164612634f6bced09335591d5424576e6410 100755 --- a/generate_samples.py +++ b/generate_samples.py @@ -56,7 +56,7 @@ def setup_model(args): iteration, release, success = get_checkpoint_iteration(args) path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt") print('current device:', torch.cuda.current_device()) - checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())) + checkpoint = torch.load(path, map_location=torch.device('cpu')) model.load_state_dict(checkpoint["module"]) print(f"Load model file {path}") else: