Skip to content
Snippets Groups Projects
Commit ccad9a0d authored by JasonJiazhiZhang, committed by Oleksandr
Browse files

Fix multiprocess eval (#150)


* episode stats changed to dict of dicts

* handle non-unique episode_id

* use tuple as key

* update while loop checking

Co-Authored-By: Erik Wijmans <ewijmans2@gmail.com>
parent 7015813a
No related branches found
No related tags found
No related merge requests found
......@@ -143,24 +143,20 @@ def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
for sensor in batch:
batch[sensor] = batch[sensor].to(device)
episode_rewards = torch.zeros(envs.num_envs, 1, device=device)
episode_spls = torch.zeros(envs.num_envs, 1, device=device)
episode_success = torch.zeros(envs.num_envs, 1, device=device)
episode_counts = torch.zeros(envs.num_envs, 1, device=device)
current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)
test_recurrent_hidden_states = torch.zeros(
args.num_processes, args.hidden_size, device=device
)
not_done_masks = torch.zeros(args.num_processes, 1, device=device)
stats_episodes = set()
stats_episodes = dict() # dict of dicts that stores stats per episode
rgb_frames = None
if args.video_option:
rgb_frames = [[]] * args.num_processes
os.makedirs(args.video_dir, exist_ok=True)
while episode_counts.sum() < args.count_test_episodes:
while len(stats_episodes) < args.count_test_episodes and envs.num_envs > 0:
current_episodes = envs.current_episodes()
with torch.no_grad():
......@@ -184,30 +180,34 @@ def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
device=device,
)
for i in range(not_done_masks.shape[0]):
if not_done_masks[i].item() == 0:
episode_spls[i] += infos[i]["spl"]
if infos[i]["spl"] > 0:
episode_success[i] += 1
rewards = torch.tensor(
rewards, dtype=torch.float, device=device
).unsqueeze(1)
current_episode_reward += rewards
episode_rewards += (1 - not_done_masks) * current_episode_reward
episode_counts += 1 - not_done_masks
current_episode_reward *= not_done_masks
next_episodes = envs.current_episodes()
envs_to_pause = []
n_envs = envs.num_envs
for i in range(n_envs):
if next_episodes[i].episode_id in stats_episodes:
if (
next_episodes[i].scene_id,
next_episodes[i].episode_id,
) in stats_episodes:
envs_to_pause.append(i)
# episode ended
if not_done_masks[i].item() == 0:
stats_episodes.add(current_episodes[i].episode_id)
episode_stats = dict()
episode_stats["spl"] = infos[i]["spl"]
episode_stats["success"] = int(infos[i]["spl"] > 0)
episode_stats["reward"] = current_episode_reward[i].item()
current_episode_reward[i] = 0
# use scene_id + episode_id as unique id for storing stats
stats_episodes[
(
current_episodes[i].scene_id,
current_episodes[i].episode_id,
)
] = episode_stats
if args.video_option:
generate_video(
args,
......@@ -224,7 +224,7 @@ def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
frame = observations_to_image(observations[i], infos[i])
rgb_frames[i].append(frame)
# stop tracking ended episodes if they exist
# pausing envs with no new episode
if len(envs_to_pause) > 0:
state_index = list(range(envs.num_envs))
for idx in reversed(envs_to_pause):
......@@ -233,7 +233,7 @@ def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
# indexing along the batch dimensions
test_recurrent_hidden_states = test_recurrent_hidden_states[
:, state_index
state_index
]
not_done_masks = not_done_masks[state_index]
current_episode_reward = current_episode_reward[state_index]
......@@ -244,9 +244,16 @@ def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
if args.video_option:
rgb_frames = [rgb_frames[i] for i in state_index]
episode_reward_mean = (episode_rewards / episode_counts).mean().item()
episode_spl_mean = (episode_spls / episode_counts).mean().item()
episode_success_mean = (episode_success / episode_counts).mean().item()
aggregated_stats = dict()
for stat_key in next(iter(stats_episodes.values())).keys():
aggregated_stats[stat_key] = sum(
[v[stat_key] for v in stats_episodes.values()]
)
num_episodes = len(stats_episodes)
episode_reward_mean = aggregated_stats["reward"] / num_episodes
episode_spl_mean = aggregated_stats["spl"] / num_episodes
episode_success_mean = aggregated_stats["success"] / num_episodes
logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
logger.info("Average episode success: {:.6f}".format(episode_success_mean))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment