diff --git a/habitat_baselines/train_ppo.py b/habitat_baselines/train_ppo.py index f6586af217f7f6b726d2d40b2933f5cf41b9b997..97efa74a62ac5f715445c2b5e021dc00cc1a73c0 100644 --- a/habitat_baselines/train_ppo.py +++ b/habitat_baselines/train_ppo.py @@ -118,14 +118,18 @@ def construct_envs(args): ) scene_split_size = int(np.floor(len(scenes) / args.num_processes)) + scene_splits = [[] for _ in range(args.num_processes)] + for j, s in enumerate(scenes): + scene_splits[j % len(scene_splits)].append(s) + + assert sum(map(len, scene_splits)) == len(scenes) + for i in range(args.num_processes): config_env = cfg_env(config_paths=args.task_config, opts=args.opts) config_env.defrost() if len(scenes) > 0: - config_env.DATASET.CONTENT_SCENES = scenes[ - i * scene_split_size : (i + 1) * scene_split_size - ] + config_env.DATASET.POINTNAVV1.CONTENT_SCENES = scene_splits[i] config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = args.sim_gpu_id