diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py
index 7e02927e2d92217b0b5e67b717eb2c6d9906471c..a7841dc07dddb3af2b23566e412853278ee79b9f 100644
--- a/habitat/core/dataset.py
+++ b/habitat/core/dataset.py
@@ -128,10 +128,11 @@ class Dataset(Generic[T]):
     def get_splits(
         self,
         num_splits: int,
-        max_episodes_per_split: Optional[int] = None,
+        episodes_per_split: Optional[int] = None,
         remove_unused_episodes: bool = False,
         collate_scene_ids: bool = True,
         sort_by_episode_id: bool = False,
+        allow_uneven_splits: bool = False,
     ) -> List["Dataset"]:
         """
         Returns a list of new datasets, each with a subset of the original
@@ -139,7 +140,7 @@ class Dataset(Generic[T]):
         episodes will be duplicated.
         Args:
             num_splits: The number of splits to create.
-            max_episodes_per_split: If provided, each split will have up to
+            episodes_per_split: If provided, each split will have up to
                 this many episodes. If it is not provided, each dataset will
                 have len(original_dataset.episodes) // num_splits episodes. If
                 max_episodes_per_split is provided and is larger than this
@@ -153,24 +154,42 @@ class Dataset(Generic[T]):
                 to each other because they will be in the same scene.
             sort_by_episode_id: If true, sequences are sorted by their episode
                 ID in the returned splits.
+            allow_uneven_splits: If true, the last split can be shorter than
+                the others. This is especially useful for splitting over
+                validation/test datasets in order to make sure that all
+                episodes are copied but none are duplicated.
         Returns:
             A list of new datasets, each with their own subset of episodes.
         """
-
         assert (
             len(self.episodes) >= num_splits
         ), "Not enough episodes to create this many splits."
+        if episodes_per_split is not None:
+            assert not allow_uneven_splits, (
+                "You probably don't want to specify allow_uneven_splits"
+                " and episodes_per_split."
+            )
+            assert num_splits * episodes_per_split <= len(self.episodes)
 
         new_datasets = []
-        if max_episodes_per_split is None:
-            max_episodes_per_split = len(self.episodes) // num_splits
-        max_episodes_per_split = min(
-            max_episodes_per_split, (len(self.episodes) // num_splits)
-        )
+
+        if allow_uneven_splits:
+            stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits))
+            split_lengths = [stride] * (num_splits - 1)
+            split_lengths.append(
+                (len(self.episodes) - stride * (num_splits - 1))
+            )
+        else:
+            if episodes_per_split is not None:
+                stride = episodes_per_split
+            else:
+                stride = len(self.episodes) // num_splits
+            split_lengths = [stride] * num_splits
+
+        num_episodes = sum(split_lengths)
+
         rand_items = np.random.choice(
-            len(self.episodes),
-            num_splits * max_episodes_per_split,
-            replace=False,
+            len(self.episodes), num_episodes, replace=False
         )
         if collate_scene_ids:
             scene_ids = {}
@@ -187,7 +206,7 @@ class Dataset(Generic[T]):
             new_dataset = copy.copy(self)  # Creates a shallow copy
             new_dataset.episodes = []
             new_datasets.append(new_dataset)
-            for ii in range(max_episodes_per_split):
+            for ii in range(split_lengths[nn]):
                 new_dataset.episodes.append(self.episodes[rand_items[ep_ind]])
                 ep_ind += 1
             if sort_by_episode_id:
@@ -196,35 +215,3 @@ class Dataset(Generic[T]):
         if remove_unused_episodes:
             self.episodes = new_episodes
         return new_datasets
-
-    def get_uneven_splits(self, num_splits):
-        """
-        Returns a list of new datasets, each with a subset of the original
-        episodes. The last dataset may have fewer episodes than the others.
-        This is especially useful for splitting over validation/test datasets
-        in order to make sure that all episodes are copied but none are
-        duplicated.
-        Args:
-            num_splits: The number of splits to create.
-        Returns:
-            A list of new datasets, each with their own subset of episodes.
-        """
-        assert (
-            len(self.episodes) >= num_splits
-        ), "Not enough episodes to create this many splits."
-        new_datasets = []
-        num_episodes = len(self.episodes)
-        stride = int(np.ceil(num_episodes * 1.0 / num_splits))
-        for ii, split in enumerate(
-            range(0, num_episodes, stride)[:num_splits]
-        ):
-            new_dataset = copy.copy(self)  # Creates a shallow copy
-            new_dataset.episodes = new_dataset.episodes[
-                split : min(split + stride, num_episodes)
-            ].copy()
-            new_datasets.append(new_dataset)
-        assert (
-            sum([len(new_dataset.episodes) for new_dataset in new_datasets])
-            == num_episodes
-        )
-        return new_datasets
diff --git a/test/test_dataset.py b/test/test_dataset.py
index 69b2bf69f83ae38d1460b9934753f2481767eb9d..975f6ada66300e81ac41ed6b5b5cbe315d219339 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -64,7 +64,7 @@ def test_get_splits_with_remainder():
         assert len(split.episodes) == 9
 
 
-def test_get_splits_max_episodes_specified():
+def test_get_splits_num_episodes_specified():
     dataset = _construct_dataset(100)
     splits = dataset.get_splits(10, 3, False)
     assert len(splits) == 10
@@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified():
     assert len(dataset.episodes) == 100
 
     dataset = _construct_dataset(100)
-    splits = dataset.get_splits(10, 11, False)
+    splits = dataset.get_splits(10, 10)
     assert len(splits) == 10
     for split in splits:
         assert len(split.episodes) == 10
@@ -86,6 +86,13 @@ def test_get_splits_max_episodes_specified():
         assert len(split.episodes) == 3
     assert len(dataset.episodes) == 30
 
+    dataset = _construct_dataset(100)
+    try:
+        splits = dataset.get_splits(10, 20)
+        assert False
+    except AssertionError:
+        pass
+
 
 def test_get_splits_collate_scenes():
     dataset = _construct_dataset(10000)
@@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id():
 
 
 def test_get_uneven_splits():
-    dataset = _construct_dataset(100)
-    splits = dataset.get_uneven_splits(9)
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=False)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 9
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=True)
     assert len(splits) == 9
-    assert sum([len(split.episodes) for split in splits]) == 100
+    assert sum([len(split.episodes) for split in splits]) == 10000
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, allow_uneven_splits=True)
+    assert len(splits) == 10
+    assert sum([len(split.episodes) for split in splits]) == 10000