Skip to content
Snippets Groups Projects
Commit fbdca8ba authored by Erik Wijmans's avatar Erik Wijmans Committed by Abhishek Kadian
Browse files

Fix resume_all() (#88)

Close paused workers
parent e703e610
No related branches found
No related tags found
No related merge requests found
...@@ -315,10 +315,19 @@ class VectorEnv: ...@@ -315,10 +315,19 @@ class VectorEnv:
if self._is_waiting: if self._is_waiting:
for read_fn in self._connection_read_fns: for read_fn in self._connection_read_fns:
read_fn() read_fn()
for write_fn in self._connection_write_fns: for write_fn in self._connection_write_fns:
write_fn((CLOSE_COMMAND, None)) write_fn((CLOSE_COMMAND, None))
for _, _, write_fn, _ in self._paused:
write_fn((CLOSE_COMMAND, None))
for process in self._workers: for process in self._workers:
process.join() process.join()
for _, _, _, process in self._paused:
process.join()
self._is_closed = True self._is_closed = True
def pause_at(self, index: int) -> None: def pause_at(self, index: int) -> None:
...@@ -342,7 +351,7 @@ class VectorEnv: ...@@ -342,7 +351,7 @@ class VectorEnv:
def resume_all(self) -> None: def resume_all(self) -> None:
"""Resumes any paused envs. """Resumes any paused envs.
""" """
for index, read_fn, write_fn, worker in self._paused: for index, read_fn, write_fn, worker in reversed(self._paused):
self._connection_read_fns.insert(index, read_fn) self._connection_read_fns.insert(index, read_fn)
self._connection_write_fns.insert(index, write_fn) self._connection_write_fns.insert(index, write_fn)
self._workers.insert(index, worker) self._workers.insert(index, worker)
......
...@@ -309,8 +309,8 @@ def test_vec_env_call_func(): ...@@ -309,8 +309,8 @@ def test_vec_env_call_func():
env_ids = envs.call(["get_env_ind"] * num_envs) env_ids = envs.call(["get_env_ind"] * num_envs)
assert env_ids == true_env_ids assert env_ids == true_env_ids
envs.pause_at(3) envs.pause_at(0)
true_env_ids.pop(3) true_env_ids.pop(0)
env_ids = envs.call(["get_env_ind"] * num_envs) env_ids = envs.call(["get_env_ind"] * num_envs)
assert env_ids == true_env_ids assert env_ids == true_env_ids
...@@ -325,6 +325,21 @@ def test_vec_env_call_func(): ...@@ -325,6 +325,21 @@ def test_vec_env_call_func():
envs.close() envs.close()
def test_close_with_paused():
configs, datasets = _load_test_data()
num_envs = len(configs)
env_fn_args = tuple(zip(configs, datasets, range(num_envs)))
with habitat.VectorEnv(
env_fn_args=env_fn_args, multiprocessing_start_method="forkserver"
) as envs:
envs.reset()
envs.pause_at(3)
envs.pause_at(0)
assert envs._is_closed
# TODO Bring back this test for the greedy follower # TODO Bring back this test for the greedy follower
@pytest.mark.skip @pytest.mark.skip
def test_action_space_shortest_path(): def test_action_space_shortest_path():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment