Unverified commit 2729bf98 authored by Oleksandr, committed by GitHub

Fixed bug with uneven splits for dataset. (#192)

* Fixed bug with uneven splits for dataset
* Fixed check if test data was loaded for CI
parent 5adbf5c3
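For context: with allow_uneven_splits=True, the previous implementation (the removed lines in the second hunk below) rounded the stride up with np.ceil and pushed whatever remained into the last split. For some shapes the rounded stride overshoots the episode count and the last split ends up empty or even with a negative length. A minimal sketch of that arithmetic, using the (994 episodes, 64 splits) case from the new parametrized test:

    import numpy as np

    num_episodes, num_splits = 994, 64

    # Old allow_uneven_splits computation (see the removed lines below).
    stride = int(np.ceil(num_episodes * 1.0 / num_splits))          # 16
    split_lengths = [stride] * (num_splits - 1)                     # 63 splits of 16 episodes = 1008
    split_lengths.append(num_episodes - stride * (num_splits - 1))
    print(split_lengths[-1])                                        # -14: the last split has negative length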
...
@@ -150,7 +150,7 @@ jobs:
       - run:
           name: Download test data
           command: |
-            if [ ! -d ./habitat-sim/data/scene_datasets/habitat-test-scenes/van-gogh-room.glb ]
+            if [ ! -f ./habitat-sim/data/scene_datasets/habitat-test-scenes/van-gogh-room.glb ]
             then
               cd habitat-sim
               wget http://dl.fbaipublicfiles.com/habitat/habitat-test-scenes.zip
...
...
@@ -66,6 +66,14 @@ class Dataset(Generic[T]):
     """

     episodes: List[T]

+    @property
+    def num_episodes(self) -> int:
+        r"""
+        Returns:
+            number of episodes in the dataset.
+        """
+        return len(self.episodes)
+
     @property
     def scene_ids(self) -> List[str]:
         r"""
...
@@ -180,7 +188,7 @@ class Dataset(Generic[T]):
                 same scene.
             sort_by_episode_id: if true, sequences are sorted by their episode
                 ID in the returned splits.
-            allow_uneven_splits: if true, the last split can be shorter than
+            allow_uneven_splits: if true, the last splits can be shorter than
                 the others. This is especially useful for splitting over
                 validation/test datasets in order to make sure that all
                 episodes are copied but none are duplicated.
...
@@ -188,35 +196,40 @@ class Dataset(Generic[T]):

         Returns:
             a list of new datasets, each with their own subset of episodes.
         """
-        assert (
-            len(self.episodes) >= num_splits
-        ), "Not enough episodes to create this many splits."
-        if episodes_per_split is not None:
-            assert not allow_uneven_splits, (
-                "You probably don't want to specify allow_uneven_splits"
-                " and episodes_per_split."
-            )
-            assert num_splits * episodes_per_split <= len(self.episodes)
+        if self.num_episodes < num_splits:
+            raise ValueError(
+                "Not enough episodes to create those many splits."
+            )
+
+        if episodes_per_split is not None:
+            if allow_uneven_splits:
+                raise ValueError(
+                    "You probably don't want to specify allow_uneven_splits"
+                    " and episodes_per_split."
+                )
+
+            if num_splits * episodes_per_split > self.num_episodes:
+                raise ValueError(
+                    "Not enough episodes to create those many splits."
+                )

         new_datasets = []
-        if allow_uneven_splits:
-            stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits))
-            split_lengths = [stride] * (num_splits - 1)
-            split_lengths.append(
-                (len(self.episodes) - stride * (num_splits - 1))
-            )
-        else:
-            if episodes_per_split is not None:
-                stride = episodes_per_split
-            else:
-                stride = len(self.episodes) // num_splits
-            split_lengths = [stride] * num_splits
+
+        if episodes_per_split is not None:
+            stride = episodes_per_split
+        else:
+            stride = self.num_episodes // num_splits
+        split_lengths = [stride] * num_splits
+
+        if allow_uneven_splits:
+            episodes_left = self.num_episodes - stride * num_splits
+            split_lengths[:episodes_left] = [stride + 1] * episodes_left
+            assert sum(split_lengths) == self.num_episodes

         num_episodes = sum(split_lengths)
         rand_items = np.random.choice(
-            len(self.episodes), num_episodes, replace=False
+            self.num_episodes, num_episodes, replace=False
         )
         if collate_scene_ids:
             scene_ids = {}
...
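To make the new behaviour concrete, here is a small standalone sketch that mirrors the length computation added above (the helper name split_lengths_for is only for illustration and is not part of the library), evaluated on a few of the (num_episodes, num_splits) pairs from the new test parametrization:

    from typing import List

    def split_lengths_for(
        num_episodes: int, num_splits: int, allow_uneven_splits: bool
    ) -> List[int]:
        # Floor-based stride; the remainder is spread one extra episode
        # at a time over the first splits instead of being dumped into the last one.
        stride = num_episodes // num_splits
        lengths = [stride] * num_splits
        if allow_uneven_splits:
            episodes_left = num_episodes - stride * num_splits
            lengths[:episodes_left] = [stride + 1] * episodes_left
        return lengths

    for n, k in [(994, 64), (1025, 64), (10000, 9)]:
        even = sum(split_lengths_for(n, k, allow_uneven_splits=False))
        uneven = sum(split_lengths_for(n, k, allow_uneven_splits=True))
        print(n, k, even, uneven)   # 994 64 960 994 / 1025 64 1024 1025 / 10000 9 9999 10000

Even splits drop the remainder episodes, while uneven splits keep every episode exactly once, which is what the docstring promises for validation/test splitting.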
...
@@ -91,11 +91,8 @@ def test_get_splits_num_episodes_specified():
     assert len(dataset.episodes) == 30

     dataset = _construct_dataset(100)
-    try:
+    with pytest.raises(ValueError):
         splits = dataset.get_splits(10, 20)
-        assert False
-    except AssertionError:
-        pass


 def test_get_splits_collate_scenes():
...
@@ -165,21 +162,21 @@ def test_get_splits_sort_by_episode_id():
             assert ep.episode_id >= split.episodes[ii - 1].episode_id


-def test_get_uneven_splits():
-    dataset = _construct_dataset(10000)
-    splits = dataset.get_splits(9, allow_uneven_splits=False)
-    assert len(splits) == 9
-    assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 9
-
-    dataset = _construct_dataset(10000)
-    splits = dataset.get_splits(9, allow_uneven_splits=True)
-    assert len(splits) == 9
-    assert sum([len(split.episodes) for split in splits]) == 10000
-
-    dataset = _construct_dataset(10000)
-    splits = dataset.get_splits(10, allow_uneven_splits=True)
-    assert len(splits) == 10
-    assert sum([len(split.episodes) for split in splits]) == 10000
+@pytest.mark.parametrize(
+    "num_episodes,num_splits",
+    [(994, 64), (1023, 64), (1024, 64), (1025, 64), (10000, 9), (10000, 10)],
+)
+def test_get_splits_func(num_episodes: int, num_splits: int):
+    dataset = _construct_dataset(num_episodes)
+    splits = dataset.get_splits(num_splits, allow_uneven_splits=True)
+    assert len(splits) == num_splits
+    assert sum([len(split.episodes) for split in splits]) == num_episodes
+    splits = dataset.get_splits(num_splits, allow_uneven_splits=False)
+    assert len(splits) == num_splits
+    assert (
+        sum(map(lambda s: s.num_episodes, splits))
+        == (num_episodes // num_splits) * num_splits
+    )


 def test_sample_episodes():
...
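A note on the test changes: get_splits now raises ValueError instead of tripping an assert, so the old try/except AssertionError/assert False pattern is replaced by pytest.raises, which fails the test on its own (with "DID NOT RAISE") if no exception occurs. A minimal, self-contained sketch of that pattern (the function might_fail is just a stand-in for illustration):

    import pytest


    def might_fail(n: int) -> None:
        if n < 0:
            raise ValueError("negative input")


    def test_raises_value_error():
        # Passes: the expected exception is raised inside the block.
        with pytest.raises(ValueError):
            might_fail(-1)
        # If might_fail(1) were used instead, pytest would fail the test
        # with "Failed: DID NOT RAISE <class 'ValueError'>".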