diff --git a/habitat/config/default.py b/habitat/config/default.py
index 8fd60e74baaa3f7ce6b2bbabfa2ec5c120d812f6..a433e997ccaef07e1f2d8bdd93e77504c145d32e 100644
--- a/habitat/config/default.py
+++ b/habitat/config/default.py
@@ -22,6 +22,12 @@ _C.SEED = 100
 _C.ENVIRONMENT = CN()
 _C.ENVIRONMENT.MAX_EPISODE_STEPS = 1000
 _C.ENVIRONMENT.MAX_EPISODE_SECONDS = 10000000
+_C.ENVIRONMENT.ITERATOR_OPTIONS = CN()
+_C.ENVIRONMENT.ITERATOR_OPTIONS.CYCLE = True
+_C.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
+_C.ENVIRONMENT.ITERATOR_OPTIONS.GROUP_BY_SCENE = True
+_C.ENVIRONMENT.ITERATOR_OPTIONS.NUM_EPISODE_SAMPLE = -1
+_C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT = -1
 # -----------------------------------------------------------------------------
 # TASK
 # -----------------------------------------------------------------------------
@@ -150,7 +156,6 @@ _C.DATASET = CN()
 _C.DATASET.TYPE = "PointNav-v1"
 _C.DATASET.SPLIT = "train"
 _C.DATASET.SCENES_DIR = "data/scene_datasets"
-_C.DATASET.NUM_EPISODE_SAMPLE = -1
 _C.DATASET.CONTENT_SCENES = ["*"]
 _C.DATASET.DATA_PATH = (
     "data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py
index b0bc35ae4d924384f5d545ce6aab400e3cc43fda..c94736ea39398e2a4a723c30f0ed9188a5c5b479 100644
--- a/habitat/core/dataset.py
+++ b/habitat/core/dataset.py
@@ -10,8 +10,18 @@ of a ``habitat.Agent`` inside ``habitat.Env``.
 """
 import copy
 import json
-from itertools import cycle
-from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
+import random
+from itertools import groupby
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+)
 
 import attr
 import numpy as np
@@ -48,16 +58,12 @@ class Episode:
     info: Optional[Dict[str, str]] = None
 
 
-T = TypeVar("T", Episode, Type[Episode])
+T = TypeVar("T", bound=Episode)
 
 
 class Dataset(Generic[T]):
     r"""Base class for dataset specification.
-
-    Attributes:
-        episodes: list of episodes containing instance information.
     """
-
     episodes: List[T]
 
     @property
@@ -90,15 +96,20 @@ class Dataset(Generic[T]):
         """
         return [self.episodes[episode_id] for episode_id in indexes]
 
-    def get_episode_iterator(self):
-        r"""
-        Creates and returns an iterator that iterates through self.episodes
-        in the desirable way specified.
+    def get_episode_iterator(self, *args: Any, **kwargs: Any) -> Iterator:
+        r"""Gets episode iterator with options. Options are specified in
+        EpisodeIterator documentation. To further customize iterator behavior
+        for your Dataset subclass, create a customized iterator class like
+        EpisodeIterator and override this method.
+
+        Args:
+            *args: positional args for iterator constructor
+            **kwargs: keyword args for iterator constructor
+
         Returns:
-            iterator for episodes
+            Iterator: episode iterator with specified behavior
         """
-        # TODO: support shuffling between epoch and  scene switching
-        return cycle(self.episodes)
+        return EpisodeIterator(self.episodes, *args, **kwargs)
 
     def to_json(self) -> str:
         class DatasetJSONEncoder(json.JSONEncoder):
@@ -111,8 +122,7 @@ class Dataset(Generic[T]):
     def from_json(
         self, json_str: str, scenes_dir: Optional[str] = None
     ) -> None:
-        r"""
-        Creates dataset from ``json_str``. Directory containing relevant 
+        r"""Creates dataset from ``json_str``. Directory containing relevant
         graphical assets of scenes is passed through ``scenes_dir``.
 
         Args:
@@ -122,11 +132,8 @@ class Dataset(Generic[T]):
         """
         raise NotImplementedError
 
-    def filter_episodes(
-        self, filter_fn: Callable[[Episode], bool]
-    ) -> "Dataset":
-        r"""
-        Returns a new dataset with only the filtered episodes from the 
+    def filter_episodes(self, filter_fn: Callable[[T], bool]) -> "Dataset":
+        r"""Returns a new dataset with only the filtered episodes from the
         original dataset.
 
         Args:
@@ -236,20 +243,118 @@ class Dataset(Generic[T]):
             self.episodes = new_episodes
         return new_datasets
 
-    def sample_episodes(self, num_episodes: int) -> None:
-        """
-        Sample from existing episodes a list of episodes of size num_episodes,
-        and replace self.episodes with the list of sampled episodes.
+
+class EpisodeIterator(Iterator):
+    r"""Episode Iterator class that gives options for how a list of episodes
+    should be iterated. Some of those options are desirable for the internal
+    simulator to get higher performance. More context: simulator suffers
+    overhead when switching between scenes, therefore episodes of the same
+    scene should be loaded consecutively. However, if too many consecutive
+    episodes from the same scene are fed into an RL model, the model risks
+    overfitting to that scene. Therefore it's better to load the same scene
+    consecutively and switch once a threshold count is reached.
+
+    Currently supports the following features:
+        Cycling: when all episodes are iterated, cycle back to start instead of
+            throwing StopIteration.
+        Cycling with shuffle: when cycling back, shuffle episode groups
+            grouped by scene.
+        Group by scene: episodes of same scene will be grouped and loaded
+            consecutively.
+        Set max scene repeat: set a threshold on how many episodes from
+            the same scene can be loaded consecutively.
+        Sample episodes: sample the specified number of episodes.
+    """
+
+    def __init__(
+        self,
+        episodes: List[T],
+        cycle: bool = True,
+        shuffle: bool = False,
+        group_by_scene: bool = True,
+        max_scene_repeat: int = -1,
+        num_episode_sample: int = -1,
+    ):
+        r"""
         Args:
-            num_episodes: number of episodes to sample, input -1 to use
-            whole episodes
+            episodes: list of episodes.
+            cycle: if true, start over when all episodes have been iterated.
+            shuffle: if true, shuffle episodes each time the iterator cycles.
+                No effect if cycle is set to false. Shuffles whole scene
+                groups if group_by_scene is true.
+            group_by_scene: if true, group episodes from same scene.
+            max_scene_repeat: threshold of how many episodes from the same
+                scene can be loaded consecutively. -1 (or 0) for no limit.
+            num_episode_sample: number of episodes to be sampled.
+                -1 for no sampling.
         """
-        if num_episodes == -1:
-            return
-        if num_episodes < -1:
-            raise ValueError(
-                f"Invalid number for episodes to sample: {num_episodes}"
+        # sample episodes
+        if num_episode_sample >= 0:
+            episodes = np.random.choice(
+                episodes, num_episode_sample, replace=False
             )
-        self.episodes = np.random.choice(
-            self.episodes, num_episodes, replace=False
-        )
+        self.episodes = episodes
+        self.cycle = cycle
+        self.group_by_scene = group_by_scene
+        if group_by_scene:
+            num_scene_groups = len(
+                list(groupby(episodes, key=lambda x: x.scene_id))
+            )
+            num_unique_scenes = len(set([e.scene_id for e in episodes]))
+            if num_scene_groups >= num_unique_scenes:
+                self.episodes = sorted(self.episodes, key=lambda x: x.scene_id)
+        self.max_scene_repetition = max_scene_repeat
+        self.shuffle = shuffle
+        self._rep_count = 0
+        self._prev_scene_id = None
+        self._iterator = iter(self.episodes)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        r"""The main logic for handling how episodes will be iterated.
+
+        Returns:
+            next episode.
+        """
+
+        next_episode = next(self._iterator, None)
+        if next_episode is None:
+            if not self.cycle:
+                raise StopIteration
+            self._iterator = iter(self.episodes)
+            if self.shuffle:
+                self._shuffle_iterator()
+            next_episode = next(self._iterator)
+
+        if self._prev_scene_id == next_episode.scene_id:
+            self._rep_count += 1
+        if (
+            self.max_scene_repetition > 0
+            and self._rep_count >= self.max_scene_repetition - 1
+        ):
+            self._shuffle_iterator()
+            self._rep_count = 0
+
+        self._prev_scene_id = next_episode.scene_id
+        return next_episode
+
+    def _shuffle_iterator(self) -> None:
+        r"""Internal method that shuffles the remaining episodes.
+        If self.group_by_scene is true, whole scene groups are shuffled.
+
+        Returns:
+            None.
+        """
+        if self.group_by_scene:
+            grouped_episodes = [
+                list(g)
+                for k, g in groupby(self._iterator, key=lambda x: x.scene_id)
+            ]
+            random.shuffle(grouped_episodes)
+            self._iterator = iter(sum(grouped_episodes, []))
+        else:
+            episodes = list(self._iterator)
+            random.shuffle(episodes)
+            self._iterator = iter(episodes)
diff --git a/habitat/core/env.py b/habitat/core/env.py
index 926763306c928233107676e224214a353d6ddc88..d948ece59c98bfece847e720db83533fb6bc8b01 100644
--- a/habitat/core/env.py
+++ b/habitat/core/env.py
@@ -72,7 +72,13 @@ class Env:
             )
         self._episodes = self._dataset.episodes if self._dataset else []
         self._current_episode = None
-        self._episode_iterator = self._dataset.get_episode_iterator()
+        iter_option_dict = {
+            k.lower(): v
+            for k, v in config.ENVIRONMENT.ITERATOR_OPTIONS.items()
+        }
+        self._episode_iterator = self._dataset.get_episode_iterator(
+            **iter_option_dict
+        )
 
         # load the first scene if dataset is present
         if self._dataset:
diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py
index 4194fb58aefe230abc2ca33ee4da515125a79ee0..f256a5164f32aa004b574875bf4b69c6d37a41fa 100644
--- a/habitat/datasets/eqa/mp3d_eqa_dataset.py
+++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py
@@ -52,8 +52,6 @@ class Matterport3dDatasetV1(Dataset):
         with gzip.open(config.DATA_PATH.format(split=config.SPLIT), "rt") as f:
             self.from_json(f.read())
 
-        self.sample_episodes(config.NUM_EPISODE_SAMPLE)
-
     def from_json(
         self, json_str: str, scenes_dir: Optional[str] = None
     ) -> None:
diff --git a/habitat/datasets/pointnav/pointnav_dataset.py b/habitat/datasets/pointnav/pointnav_dataset.py
index 70caad5e9bb3078f9c86b5e01182dfdbb868ca84..9e9cdf516f6cd28cd5edd5d0f187bcb6921966aa 100644
--- a/habitat/datasets/pointnav/pointnav_dataset.py
+++ b/habitat/datasets/pointnav/pointnav_dataset.py
@@ -98,8 +98,6 @@ class PointNavDatasetV1(Dataset):
             with gzip.open(scene_filename, "rt") as f:
                 self.from_json(f.read(), scenes_dir=config.SCENES_DIR)
 
-        self.sample_episodes(config.NUM_EPISODE_SAMPLE)
-
     def from_json(
         self, json_str: str, scenes_dir: Optional[str] = None
     ) -> None:
diff --git a/test/test_dataset.py b/test/test_dataset.py
index 20792a264fbd041c19e123397805eb375883c031..ea5b03b712f68b39c1a2cd5b0af0dd68c4366165 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -4,17 +4,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from itertools import groupby, islice
+
 import pytest
 
 from habitat.core.dataset import Dataset, Episode
 
 
-def _construct_dataset(num_episodes):
+def _construct_dataset(num_episodes, num_groups=10):
     episodes = []
-    for ii in range(num_episodes):
+    for i in range(num_episodes):
         episode = Episode(
-            episode_id=str(ii),
-            scene_id="scene_id_" + str(ii % 10),
+            episode_id=str(i),
+            scene_id="scene_id_" + str(i % num_groups),
             start_position=[0, 0, 0],
             start_rotation=[0, 0, 0, 1],
         )
@@ -181,30 +183,84 @@ def test_get_uneven_splits():
 
 
 def test_sample_episodes():
-    dataset = _construct_dataset(10000)
-    dataset.sample_episodes(-1)
-    assert len(dataset.episodes) == 10000
+    dataset = _construct_dataset(1000)
+    ep_iter = dataset.get_episode_iterator(
+        num_episode_sample=1000, cycle=False
+    )
+    assert len(list(ep_iter)) == 1000
 
-    dataset = _construct_dataset(10000)
-    dataset.sample_episodes(0)
-    assert len(dataset.episodes) == 0
+    ep_iter = dataset.get_episode_iterator(num_episode_sample=0, cycle=False)
+    assert len(list(ep_iter)) == 0
 
-    dataset = _construct_dataset(10000)
-    dataset.sample_episodes(1)
-    assert len(dataset.episodes) == 1
+    with pytest.raises(ValueError):
+        dataset.get_episode_iterator(num_episode_sample=1001, cycle=False)
 
-    dataset = _construct_dataset(10000)
-    dataset.sample_episodes(10000)
-    assert len(dataset.episodes) == 10000
+    ep_iter = dataset.get_episode_iterator(num_episode_sample=100, cycle=True)
+    ep_id_list = [e.episode_id for e in list(islice(ep_iter, 100))]
+    assert len(set(ep_id_list)) == 100
+    next_episode = next(ep_iter)
+    assert next_episode.episode_id in ep_id_list
 
-    dataset = _construct_dataset(10000)
-    with pytest.raises(Exception):
-        dataset.sample_episodes(10001)
+    ep_iter = dataset.get_episode_iterator(num_episode_sample=0, cycle=False)
+    with pytest.raises(StopIteration):
+        next(ep_iter)
 
 
-def test_iterator_looping():
+def test_iterator_cycle():
     dataset = _construct_dataset(100)
-    episode_iter = dataset.get_episode_iterator()
+    ep_iter = dataset.get_episode_iterator(
+        cycle=True, shuffle=False, group_by_scene=False
+    )
     for i in range(200):
-        episode = next(episode_iter)
+        episode = next(ep_iter)
         assert episode.episode_id == dataset.episodes[i % 100].episode_id
+
+    ep_iter = dataset.get_episode_iterator(cycle=True, num_episode_sample=20)
+    episodes = list(islice(ep_iter, 20))
+    for i in range(200):
+        episode = next(ep_iter)
+        assert episode.episode_id == episodes[i % 20].episode_id
+
+
+def test_iterator_shuffle():
+    dataset = _construct_dataset(100)
+    episode_iter = dataset.get_episode_iterator(shuffle=True)
+    first_round_episodes = list(islice(episode_iter, 100))
+    second_round_episodes = list(islice(episode_iter, 100))
+
+    # both rounds should have same episodes but in different order
+    assert sorted(first_round_episodes) == sorted(second_round_episodes)
+    assert first_round_episodes != second_round_episodes
+
+    # both rounds should be grouped by scenes
+    first_round_scene_groups = [
+        k for k, g in groupby(first_round_episodes, key=lambda x: x.scene_id)
+    ]
+    second_round_scene_groups = [
+        k for k, g in groupby(second_round_episodes, key=lambda x: x.scene_id)
+    ]
+    assert len(first_round_scene_groups) == len(second_round_scene_groups)
+    assert len(first_round_scene_groups) == len(set(first_round_scene_groups))
+
+
+def test_iterator_scene_switching():
+    total_ep = 1000
+    max_repeat = 25
+    dataset = _construct_dataset(total_ep)
+
+    episode_iter = dataset.get_episode_iterator(max_scene_repeat=max_repeat)
+    episodes = sorted(dataset.episodes, key=lambda x: x.scene_id)
+
+    # episodes before max_repeat reached should be identical
+    for i in range(max_repeat):
+        episode = next(episode_iter)
+        assert episode.episode_id == episodes.pop(0).episode_id
+
+    remaining_episodes = list(islice(episode_iter, total_ep - max_repeat))
+    # remaining episodes should be same but in different order
+    assert len(remaining_episodes) == len(episodes)
+    assert remaining_episodes != episodes
+    assert sorted(remaining_episodes) == sorted(episodes)
+
+    # next episodes should still be grouped by scene (before next switching)
+    assert len(set([e.scene_id for e in remaining_episodes[:max_repeat]])) == 1