From b4ad1cace5019dd2a51353534bac9747aa5071d7 Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Thu, 17 Jun 2021 23:49:34 +0800 Subject: [PATCH] load weights to cpu instead of gpu --- generate_samples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generate_samples.py b/generate_samples.py index c6876a5..b357164 100755 --- a/generate_samples.py +++ b/generate_samples.py @@ -56,7 +56,7 @@ def setup_model(args): iteration, release, success = get_checkpoint_iteration(args) path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt") print('current device:', torch.cuda.current_device()) - checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())) + checkpoint = torch.load(path, map_location=torch.device('cpu')) model.load_state_dict(checkpoint["module"]) print(f"Load model file {path}") else: -- GitLab