diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py index 67c405a4d214fc22b4725220f0040bc43d31aaf8..c956d9f07138cafe480a9fc2a39a961da0e28a35 100644 --- a/habitat/core/vector_env.py +++ b/habitat/core/vector_env.py @@ -315,10 +315,19 @@ class VectorEnv: if self._is_waiting: for read_fn in self._connection_read_fns: read_fn() + for write_fn in self._connection_write_fns: write_fn((CLOSE_COMMAND, None)) + + for _, _, write_fn, _ in self._paused: + write_fn((CLOSE_COMMAND, None)) + for process in self._workers: process.join() + + for _, _, _, process in self._paused: + process.join() + self._is_closed = True def pause_at(self, index: int) -> None: @@ -342,7 +351,7 @@ class VectorEnv: def resume_all(self) -> None: """Resumes any paused envs. """ - for index, read_fn, write_fn, worker in self._paused: + for index, read_fn, write_fn, worker in reversed(self._paused): self._connection_read_fns.insert(index, read_fn) self._connection_write_fns.insert(index, write_fn) self._workers.insert(index, worker) diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index 5c1c5961efd5973eba2be3bc44678baf15c72281..274b5c0e0ae36cc5d6048fada9e13ebfece20889 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -309,8 +309,8 @@ def test_vec_env_call_func(): env_ids = envs.call(["get_env_ind"] * num_envs) assert env_ids == true_env_ids - envs.pause_at(3) - true_env_ids.pop(3) + envs.pause_at(0) + true_env_ids.pop(0) env_ids = envs.call(["get_env_ind"] * num_envs) assert env_ids == true_env_ids @@ -325,6 +325,21 @@ def test_vec_env_call_func(): envs.close() +def test_close_with_paused(): + configs, datasets = _load_test_data() + num_envs = len(configs) + env_fn_args = tuple(zip(configs, datasets, range(num_envs))) + with habitat.VectorEnv( + env_fn_args=env_fn_args, multiprocessing_start_method="forkserver" + ) as envs: + envs.reset() + + envs.pause_at(3) + envs.pause_at(0) + + assert envs._is_closed + + # TODO Bring back this test for the greedy follower @pytest.mark.skip def test_action_space_shortest_path():