From fbdca8ba67d1451ea89d54f0fb16dda9e6d9a68a Mon Sep 17 00:00:00 2001 From: Erik Wijmans <etw@gatech.edu> Date: Tue, 21 May 2019 16:37:15 -0700 Subject: [PATCH] Fix resume_all() (#88) Close paused workers --- habitat/core/vector_env.py | 11 ++++++++++- test/test_habitat_env.py | 19 +++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py index 67c405a4d..c956d9f07 100644 --- a/habitat/core/vector_env.py +++ b/habitat/core/vector_env.py @@ -315,10 +315,19 @@ class VectorEnv: if self._is_waiting: for read_fn in self._connection_read_fns: read_fn() + for write_fn in self._connection_write_fns: write_fn((CLOSE_COMMAND, None)) + + for _, _, write_fn, _ in self._paused: + write_fn((CLOSE_COMMAND, None)) + for process in self._workers: process.join() + + for _, _, _, process in self._paused: + process.join() + self._is_closed = True def pause_at(self, index: int) -> None: @@ -342,7 +351,7 @@ class VectorEnv: def resume_all(self) -> None: """Resumes any paused envs. """ - for index, read_fn, write_fn, worker in self._paused: + for index, read_fn, write_fn, worker in reversed(self._paused): self._connection_read_fns.insert(index, read_fn) self._connection_write_fns.insert(index, write_fn) self._workers.insert(index, worker) diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index 5c1c5961e..274b5c0e0 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -309,8 +309,8 @@ def test_vec_env_call_func(): env_ids = envs.call(["get_env_ind"] * num_envs) assert env_ids == true_env_ids - envs.pause_at(3) - true_env_ids.pop(3) + envs.pause_at(0) + true_env_ids.pop(0) env_ids = envs.call(["get_env_ind"] * num_envs) assert env_ids == true_env_ids @@ -325,6 +325,21 @@ def test_vec_env_call_func(): envs.close() +def test_close_with_paused(): + configs, datasets = _load_test_data() + num_envs = len(configs) + env_fn_args = tuple(zip(configs, datasets, range(num_envs))) + with habitat.VectorEnv( + env_fn_args=env_fn_args, multiprocessing_start_method="forkserver" + ) as envs: + envs.reset() + + envs.pause_at(3) + envs.pause_at(0) + + assert envs._is_closed + + # TODO Bring back this test for the greedy follower @pytest.mark.skip def test_action_space_shortest_path(): -- GitLab