From 57691b6af3dd474563c416a365f048c58f720834 Mon Sep 17 00:00:00 2001
From: danielgordon10 <danielgordon10@gmail.com>
Date: Wed, 24 Apr 2019 15:54:36 -0700
Subject: [PATCH] Dataset utility functions (#49)

* Made simple agents suitable for challenge submissions.

* Update docker

* Addressed the comments, changed flag names, switched to benchmark

* Removed .swp file, added swp to gitignore.

* Adding arxiv identifier

* dataset changes

* added tests, changed interface slightly

* added collate parameter

* Addressed Oleksandr's comments
---
 habitat/core/dataset.py | 131 +++++++++++++++++++++++++++++++-
 test/test_dataset.py    | 161 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 290 insertions(+), 2 deletions(-)
 create mode 100644 test/test_dataset.py

diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py
index 4bbb5384c..7e02927e2 100644
--- a/habitat/core/dataset.py
+++ b/habitat/core/dataset.py
@@ -4,8 +4,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import json
-from typing import Dict, List, Type, TypeVar, Generic, Optional
+import random
+from typing import Dict, List, Type, TypeVar, Generic, Optional, Callable
+
+import numpy as np
 
 
 class Episode:
@@ -67,7 +71,7 @@ class Dataset(Generic[T]):
         Returns:
             unique scene ids present in the dataset
         """
-        return list({episode.scene_id for episode in self.episodes})
+        return sorted(list({episode.scene_id for episode in self.episodes}))
 
     def get_scene_episodes(self, scene_id: str) -> List[T]:
         """
@@ -101,3 +105,126 @@ class Dataset(Generic[T]):
 
     def from_json(self, json_str: str) -> None:
         raise NotImplementedError
+
+    def filter_episodes(
+        self, filter_fn: Callable[[Episode], bool]
+    ) -> "Dataset":
+        """
+        Returns a new dataset containing only the episodes from the original
+        dataset for which filter_fn returns True.
+        Args:
+            filter_fn: Function that returns True for episodes to keep.
+        Returns:
+            The new dataset.
+        """
+        new_episodes = []
+        for episode in self.episodes:
+            if filter_fn(episode):
+                new_episodes.append(episode)
+        new_dataset = copy.copy(self)
+        new_dataset.episodes = new_episodes
+        return new_dataset
+
+    def get_splits(
+        self,
+        num_splits: int,
+        max_episodes_per_split: Optional[int] = None,
+        remove_unused_episodes: bool = False,
+        collate_scene_ids: bool = True,
+        sort_by_episode_id: bool = False,
+    ) -> List["Dataset"]:
+        """
+        Returns a list of new datasets, each with a subset of the original
+        episodes. All splits will have the same number of episodes, but no
+        episodes will be duplicated.
+        Args:
+            num_splits: The number of splits to create.
+            max_episodes_per_split: If provided, each split will have up to
+                this many episodes. If not provided, each split will have
+                len(original_dataset.episodes) // num_splits episodes. If
+                max_episodes_per_split is larger than that value, it is
+                capped to it.
+            remove_unused_episodes: Once the splits are created, episodes
+                that were not assigned to any split are removed from the
+                original dataset. This saves memory for large datasets.
+            collate_scene_ids: If true, episodes with the same scene id are
+                grouped next to each other. This reduces the overhead of
+                switching between scenes, but consecutive episodes will be
+                correlated because they share a scene.
+            sort_by_episode_id: If true, episodes are sorted by their episode
+                ID within each returned split.
+        Returns:
+            A list of new datasets, each with their own subset of episodes.
+        """
+
+        assert (
+            len(self.episodes) >= num_splits
+        ), "Not enough episodes to create this many splits."
+
+        new_datasets = []
+        if max_episodes_per_split is None:
+            max_episodes_per_split = len(self.episodes) // num_splits
+        max_episodes_per_split = min(
+            max_episodes_per_split, (len(self.episodes) // num_splits)
+        )
+        rand_items = np.random.choice(
+            len(self.episodes),
+            num_splits * max_episodes_per_split,
+            replace=False,
+        )
+        if collate_scene_ids:
+            scene_ids = {}
+            for rand_ind in rand_items:
+                scene = self.episodes[rand_ind].scene_id
+                if scene not in scene_ids:
+                    scene_ids[scene] = []
+                scene_ids[scene].append(rand_ind)
+            # Flatten the per-scene index lists back into a single list
+            rand_items = [ind for inds in scene_ids.values() for ind in inds]
+        ep_ind = 0
+        new_episodes = []
+        for nn in range(num_splits):
+            new_dataset = copy.copy(self)  # Creates a shallow copy
+            new_dataset.episodes = []
+            new_datasets.append(new_dataset)
+            for ii in range(max_episodes_per_split):
+                new_dataset.episodes.append(self.episodes[rand_items[ep_ind]])
+                ep_ind += 1
+            if sort_by_episode_id:
+                new_dataset.episodes.sort(key=lambda ep: ep.episode_id)
+            new_episodes.extend(new_dataset.episodes)
+        if remove_unused_episodes:
+            self.episodes = new_episodes
+        return new_datasets
+
+    def get_uneven_splits(self, num_splits: int) -> List["Dataset"]:
+        """
+        Returns a list of new datasets, each with a subset of the original
+        episodes. The last dataset may have fewer episodes than the others.
+        This is especially useful for splitting validation/test datasets,
+        since every episode appears in exactly one split and none are
+        duplicated.
+        Args:
+            num_splits: The number of splits to create.
+        Returns:
+            A list of new datasets, each with their own subset of episodes.
+        """
+        assert (
+            len(self.episodes) >= num_splits
+        ), "Not enough episodes to create this many splits."
+        new_datasets = []
+        num_episodes = len(self.episodes)
+        stride = int(np.ceil(num_episodes * 1.0 / num_splits))
+        for ii, split in enumerate(
+            range(0, num_episodes, stride)[:num_splits]
+        ):
+            new_dataset = copy.copy(self)  # Creates a shallow copy
+            new_dataset.episodes = new_dataset.episodes[
+                split : min(split + stride, num_episodes)
+            ].copy()
+            new_datasets.append(new_dataset)
+        assert (
+            sum([len(new_dataset.episodes) for new_dataset in new_datasets])
+            == num_episodes
+        )
+        return new_datasets
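
For quick orientation, here is a minimal usage sketch of the new utilities
(illustrative only, not part of the patch; it assumes dataset is any habitat
Dataset subclass whose episodes are already loaded, and the scene ids shown
are hypothetical):

    # Hypothetical usage sketch -- not code from this patch.
    # Assumes `dataset` is a habitat Dataset with episodes already loaded.
    train_splits = dataset.get_splits(
        num_splits=8,
        max_episodes_per_split=1000,
        collate_scene_ids=True,  # keep episodes from the same scene adjacent
    )
    # Every episode lands in exactly one of the returned splits.
    val_splits = dataset.get_uneven_splits(4)
    # Keep only episodes from a chosen subset of scenes.
    kept = dataset.filter_episodes(
        lambda ep: ep.scene_id in {"scene_id_0", "scene_id_1"}
    )
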
diff --git a/test/test_dataset.py b/test/test_dataset.py
new file mode 100644
index 000000000..69b2bf69f
--- /dev/null
+++ b/test/test_dataset.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from habitat.core.dataset import Dataset, Episode
+
+
+def _construct_dataset(num_episodes):
+    episodes = []
+    for ii in range(num_episodes):
+        episode = Episode(
+            episode_id=str(ii),
+            scene_id="scene_id_" + str(ii % 10),
+            start_position=[0, 0, 0],
+            start_rotation=[0, 0, 0, 1],
+        )
+        episodes.append(episode)
+    dataset = Dataset()
+    dataset.episodes = episodes
+    return dataset
+
+
+def test_scene_ids():
+    dataset = _construct_dataset(100)
+    assert dataset.scene_ids == ["scene_id_" + str(ii) for ii in range(10)]
+
+
+def test_get_scene_episodes():
+    dataset = _construct_dataset(100)
+    scene = "scene_id_0"
+    scene_episodes = dataset.get_scene_episodes(scene)
+    assert len(scene_episodes) == 10
+    for ep in scene_episodes:
+        assert ep.scene_id == scene
+
+
+def test_filter_episodes():
+    dataset = _construct_dataset(100)
+
+    def filter_fn(episode: Episode) -> bool:
+        return int(episode.episode_id) % 2 == 0
+
+    filtered_dataset = dataset.filter_episodes(filter_fn)
+    assert len(filtered_dataset.episodes) == 50
+    for ep in filtered_dataset.episodes:
+        assert filter_fn(ep)
+
+
+def test_get_splits_even_split_possible():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 10
+
+
+def test_get_splits_with_remainder():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(11)
+    assert len(splits) == 11
+    for split in splits:
+        assert len(split.episodes) == 9
+
+
+def test_get_splits_max_episodes_specified():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 3, False)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 3
+    assert len(dataset.episodes) == 100
+
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 11, False)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 10
+    assert len(dataset.episodes) == 100
+
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 3, True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 3
+    assert len(dataset.episodes) == 30
+
+
+def test_get_splits_collate_scenes():
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 23, collate_scene_ids=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 23
+        prev_ids = set()
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                assert split.episodes[ii - 1].scene_id == ep.scene_id
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 200, collate_scene_ids=False)
+    assert len(splits) == 10
+    for split in splits:
+        prev_ids = set()
+        found_not_collated = False
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                if split.episodes[ii - 1].scene_id != ep.scene_id:
+                    found_not_collated = True
+                    break
+        assert found_not_collated
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, collate_scene_ids=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 1000
+        prev_ids = set()
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                assert split.episodes[ii - 1].scene_id == ep.scene_id
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, collate_scene_ids=False)
+    assert len(splits) == 10
+    for split in splits:
+        prev_ids = set()
+        found_not_collated = False
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                if split.episodes[ii - 1].scene_id != ep.scene_id:
+                    found_not_collated = True
+                    break
+        assert found_not_collated
+
+
+def test_get_splits_sort_by_episode_id():
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 23, sort_by_episode_id=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 23
+        for ii, ep in enumerate(split.episodes):
+            if ii > 0:
+                assert ep.episode_id >= split.episodes[ii - 1].episode_id
+
+
+def test_get_uneven_splits():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_uneven_splits(9)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == 100
-- 
GitLab
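
As a sanity check on the stride arithmetic in get_uneven_splits, here is a
standalone sketch (not part of the patch) that mirrors the computation for the
100-episode, 9-split case exercised by test_get_uneven_splits: the stride is
ceil(100 / 9) = 12, giving eight splits of 12 episodes and a final split of 4,
which sums back to 100.

    # Standalone sketch mirroring the stride logic above; not patch code.
    import numpy as np

    num_episodes, num_splits = 100, 9
    stride = int(np.ceil(num_episodes / num_splits))  # ceil(100 / 9) == 12
    sizes = [
        min(start + stride, num_episodes) - start
        for start in range(0, num_episodes, stride)[:num_splits]
    ]
    assert sizes == [12] * 8 + [4]
    assert sum(sizes) == num_episodes  # every episode covered exactly once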