diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index 628e94c0c85f22b0a92f0d3d94bce9f9f8e30b4b..5c1c5961efd5973eba2be3bc44678baf15c72281 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -221,9 +221,17 @@ def test_rl_vectorized_envs(): assert len(rewards) == num_envs assert len(dones) == num_envs assert len(infos) == num_envs - assert envs.render( - mode="rgb_array" - ).all(), "vector env render is broken" + + tiled_img = envs.render(mode="rgb_array") + new_height = int(np.ceil(np.sqrt(NUM_ENVS))) + new_width = int(np.ceil(float(NUM_ENVS) / new_height)) + h, w, c = observations[0]["rgb"].shape + assert tiled_img.shape == ( + h * new_height, + w * new_width, + c, + ), "vector env render is broken" + if (i + 1) % configs[0].ENVIRONMENT.MAX_EPISODE_STEPS == 0: assert all(dones), "dones should be true after max_episode steps"