diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py index 4618e77c7e6748852047eaf4c8a7a6f5aacf043e..124918264c540d50855c8eaf0637f4c8d668bc4f 100644 --- a/habitat/core/vector_env.py +++ b/habitat/core/vector_env.py @@ -17,6 +17,7 @@ from habitat.config import Config from habitat.core.env import Env, Observations from habitat.core.logging import logger from habitat.core.utils import tile_images +import gym STEP_COMMAND = "step" RESET_COMMAND = "reset" @@ -24,6 +25,7 @@ RENDER_COMMAND = "render" CLOSE_COMMAND = "close" OBSERVATION_SPACE_COMMAND = "observation_space" ACTION_SPACE_COMMAND = "action_space" +CALL_COMMAND = "call" def _make_env_fn( @@ -116,6 +118,7 @@ class VectorEnv: self.action_spaces = [ read_fn() for read_fn in self._connection_read_fns ] + self._paused = [] @property def num_envs(self): @@ -145,7 +148,9 @@ class VectorEnv: while command != CLOSE_COMMAND: if command == STEP_COMMAND: # different step methods for habitat.RLEnv and habitat.Env - if isinstance(env, habitat.RLEnv): + if isinstance(env, habitat.RLEnv) or isinstance( + env, gym.Env + ): # habitat.RLEnv observations, reward, done, info = env.step(data) if auto_reset_done and done: @@ -173,6 +178,13 @@ class VectorEnv: ): connection_write_fn(getattr(env, command)) + elif command == CALL_COMMAND: + function_name, function_args = data + if function_args is None or len(function_args) == 0: + result = getattr(env, function_name)() + else: + result = getattr(env, function_name)(*function_args) + connection_write_fn(result) else: raise NotImplementedError @@ -307,9 +319,90 @@ class VectorEnv: write_fn((CLOSE_COMMAND, None)) for process in self._workers: process.join() - self._is_closed = True + def pause_at(self, index: int) -> None: + """Pauses computation on this env without destroying the env. This is + useful for not needing to call steps on all environments when only + some are active (for example during the last episodes of running + eval episodes). 
+ + Args: + index: which env to pause. All indexes after this one will be + shifted down by one. + """ + if self._is_waiting: + for read_fn in self._connection_read_fns: + read_fn() + read_fn = self._connection_read_fns.pop(index) + write_fn = self._connection_write_fns.pop(index) + worker = self._workers.pop(index) + self._paused.append((index, read_fn, write_fn, worker)) + + def resume_all(self) -> None: + """Resumes any paused envs, most recently paused first. + """ + for index, read_fn, write_fn, worker in reversed(self._paused): + self._connection_read_fns.insert(index, read_fn) + self._connection_write_fns.insert(index, write_fn) + self._workers.insert(index, worker) + self._paused = [] + + def call_at( + self, + index: int, + function_name: str, + function_args: Optional[List[Any]] = None, + ) -> Any: + """Calls a function (which is passed by name) on the selected env and + returns the result. + + Args: + index: Which env to call the function on. + function_name: The name of the function to call on the env. + function_args: Optional function args. + Returns: + The result of calling the function. + """ + self._is_waiting = True + self._connection_write_fns[index]( + (CALL_COMMAND, (function_name, function_args)) + ) + result = self._connection_read_fns[index]() + self._is_waiting = False + return result + + def call( + self, + function_names: List[str], + function_args_list: Optional[List[Any]] = None, + ) -> List[Any]: + """Calls a list of functions (which are passed by name) on the + corresponding env (by index). + + Args: + function_names: The name of the functions to call on the envs. + function_args_list: List of function args for each function. If + provided, len(function_args_list) should be as long as + len(function_names). + Returns: + A list holding the result of each function call.
+ """ + self._is_waiting = True + if function_args_list is None: + function_args_list = [None] * len(function_names) + assert len(function_names) == len(function_args_list) + func_args = zip(function_names, function_args_list) + for write_fn, func_args_on in zip( + self._connection_write_fns, func_args + ): + write_fn((CALL_COMMAND, func_args_on)) + results = [] + for read_fn in self._connection_read_fns: + results.append(read_fn()) + self._is_waiting = False + return results + def render( self, mode: str = "human", *args, **kwargs ) -> Union[np.ndarray, None]: diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index 2094d5b13c20503f4b091adaa9c66ea3e46305bc..bed5f6f764d47baedbcd04e738b940082bb28294 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -18,10 +18,14 @@ from habitat.sims.habitat_simulator import SimulatorActions from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal CFG_TEST = "test/habitat_all_sensors_test.yaml" -NUM_ENVS = 2 +NUM_ENVS = 4 class DummyRLEnv(habitat.RLEnv): + def __init__(self, config, dataset=None, env_ind=0): + super(DummyRLEnv, self).__init__(config, dataset) + self._env_ind = env_ind + def get_reward_range(self): return -1.0, 1.0 @@ -37,6 +41,12 @@ class DummyRLEnv(habitat.RLEnv): def get_info(self, observations): return {} + def get_env_ind(self): + return self._env_ind + + def set_env_ind(self, new_env_ind): + self._env_ind = new_env_ind + def _load_test_data(): configs = [] @@ -257,6 +267,53 @@ def test_rl_env(): env.close() +def _make_dummy_env_func(config, dataset, id): + return DummyRLEnv(config=config, dataset=dataset, env_ind=id) + + +def test_vec_env_call_func(): + configs, datasets = _load_test_data() + num_envs = len(configs) + env_fn_args = tuple(zip(configs, datasets, range(num_envs))) + true_env_ids = list(range(num_envs)) + envs = habitat.VectorEnv( + make_env_fn=_make_dummy_env_func, + env_fn_args=env_fn_args, + multiprocessing_start_method="forkserver", + ) + 
envs.reset() + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == true_env_ids + + env_id = envs.call_at(1, "get_env_ind") + assert env_id == true_env_ids[1] + + envs.call_at(2, "set_env_ind", [20]) + true_env_ids[2] = 20 + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == true_env_ids + + envs.call_at(2, "set_env_ind", [2]) + true_env_ids[2] = 2 + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == true_env_ids + + envs.pause_at(3) + true_env_ids.pop(3) + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == true_env_ids + + envs.pause_at(0) + true_env_ids.pop(0) + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == true_env_ids + + envs.resume_all() + env_ids = envs.call(["get_env_ind"] * num_envs) + assert env_ids == list(range(num_envs)) + envs.close() + + # TODO Bring back this test for the greedy follower @pytest.mark.skip def test_action_space_shortest_path():