Commit 04368cdc authored by JasonJiazhiZhang, committed by Oleksandr
Generalize performance metric for baselines (#178)

Make eval() in ppo_trainer retrieve the name (sensor_uuid) of the performance metric in use instead of always assuming "spl". This change is needed so that ppo_trainer can support new tasks whose performance metric is something other than SPL.
parent dd75f18f
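
At the core of the change is the metric-name lookup that eval() now performs before rolling out evaluation episodes. Below is a minimal standalone sketch of that lookup, assuming a Habitat-style task config whose first MEASUREMENTS entry is the primary metric; the registry call and the _get_uuid() trick are taken from the diff below, while the import path and the helper name resolve_metric_uuid are illustrative, not part of the commit:

from habitat_baselines.common.baseline_registry import baseline_registry


def resolve_metric_uuid(task_config):
    # The first listed measurement is treated as the primary performance
    # metric, e.g. MEASUREMENTS = ["SPL", ...] -> metric_name = "SPL".
    metric_name = task_config.TASK.MEASUREMENTS[0]
    # Each measurement has its own config node (e.g. TASK.SPL) with a TYPE field.
    metric_cfg = getattr(task_config.TASK, metric_name)
    measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
    assert measure_type is not None, "invalid measurement type {}".format(
        metric_cfg.TYPE
    )
    # Instantiate without sim/config just to read the measure's uuid, e.g. "spl".
    return measure_type(None, None)._get_uuid()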
@@ -120,7 +120,8 @@ def generate_video(
     images: List[np.ndarray],
     episode_id: int,
     checkpoint_idx: int,
-    spl: float,
+    metric_name: str,
+    metric_value: float,
     tb_writer: TensorboardWriter,
     fps: int = 10,
 ) -> None:
@@ -132,7 +133,8 @@ def generate_video(
         images: list of images to be converted to video.
         episode_id: episode id for video naming.
         checkpoint_idx: checkpoint index for video naming.
-        spl: SPL for this episode for video naming.
+        metric_name: name of the performance metric, e.g. "spl".
+        metric_value: value of metric.
         tb_writer: tensorboard writer object for uploading video.
         fps: fps for generated video.
     Returns:
@@ -141,7 +143,7 @@ def generate_video(
     if len(images) < 1:
         return
 
-    video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
+    video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_{metric_name}{metric_value:.2f}"
     if "disk" in video_option:
         assert video_dir is not None
         images_to_video(images, video_dir, video_name)
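
For illustration, here is what the generalized video naming produces for a metric other than SPL; the values below are made up, and the format string is the one introduced in the hunk above:

# Hypothetical values; the f-string is the one introduced in the diff above.
episode_id, checkpoint_idx = 42, 3
metric_name, metric_value = "success", 1.0

video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_{metric_name}{metric_value:.2f}"
print(video_name)  # episode42_ckpt3_success1.00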
@@ -379,6 +379,15 @@ class PPOTrainer(BaseRLTrainer):
         self.agent.load_state_dict(ckpt_dict["state_dict"])
         self.actor_critic = self.agent.actor_critic
 
+        # get name of performance metric, e.g. "spl"
+        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
+        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
+        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
+        assert measure_type is not None, "invalid measurement type {}".format(
+            metric_cfg.TYPE
+        )
+        self.metric_uuid = measure_type(None, None)._get_uuid()
+
         observations = self.envs.reset()
         batch = batch_obs(observations)
         for sensor in batch:
@@ -457,8 +466,12 @@ class PPOTrainer(BaseRLTrainer):
                 # episode ended
                 if not_done_masks[i].item() == 0:
                     episode_stats = dict()
-                    episode_stats["spl"] = infos[i]["spl"]
-                    episode_stats["success"] = int(infos[i]["spl"] > 0)
+                    episode_stats[self.metric_uuid] = infos[i][
+                        self.metric_uuid
+                    ]
+                    episode_stats["success"] = int(
+                        infos[i][self.metric_uuid] > 0
+                    )
                     episode_stats["reward"] = current_episode_reward[i].item()
                     current_episode_reward[i] = 0
                     # use scene_id + episode_id as unique id for storing stats
@@ -476,7 +489,8 @@ class PPOTrainer(BaseRLTrainer):
                             images=rgb_frames[i],
                             episode_id=current_episodes[i].episode_id,
                             checkpoint_idx=checkpoint_index,
-                            spl=infos[i]["spl"],
+                            metric_name=self.metric_uuid,
+                            metric_value=infos[i][self.metric_uuid],
                             tb_writer=writer,
                         )
 
@@ -516,12 +530,14 @@ class PPOTrainer(BaseRLTrainer):
         num_episodes = len(stats_episodes)
 
         episode_reward_mean = aggregated_stats["reward"] / num_episodes
-        episode_spl_mean = aggregated_stats["spl"] / num_episodes
+        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
         episode_success_mean = aggregated_stats["success"] / num_episodes
 
         logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
         logger.info(f"Average episode success: {episode_success_mean:.6f}")
-        logger.info(f"Average episode SPL: {episode_spl_mean:.6f}")
+        logger.info(
+            f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}"
+        )
 
         writer.add_scalars(
             "eval_reward",
@@ -529,7 +545,9 @@ class PPOTrainer(BaseRLTrainer):
             checkpoint_index,
         )
         writer.add_scalars(
-            "eval_SPL", {"average SPL": episode_spl_mean}, checkpoint_index
+            f"eval_{self.metric_uuid}",
+            {f"average {self.metric_uuid}": episode_metric_mean},
+            checkpoint_index,
         )
         writer.add_scalars(
             "eval_success",
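
The aggregate logging follows the same pattern: the TensorBoard scalar group is derived from the metric uuid instead of the hard-coded "eval_SPL". A small self-contained sketch with made-up numbers, using torch's SummaryWriter in place of the project's TensorboardWriter wrapper:

from torch.utils.tensorboard import SummaryWriter

# Made-up evaluation stats for a task whose primary metric is "success".
metric_uuid = "success"
aggregated_stats = {"reward": 612.5, metric_uuid: 38.0}
num_episodes = 50
checkpoint_index = 3

episode_metric_mean = aggregated_stats[metric_uuid] / num_episodes

writer = SummaryWriter(log_dir="eval_demo")  # stands in for TensorboardWriter
writer.add_scalars(
    f"eval_{metric_uuid}",  # e.g. "eval_success" instead of "eval_SPL"
    {f"average {metric_uuid}": episode_metric_mean},
    checkpoint_index,
)
writer.close()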