Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Habitat Lab
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
Meta Research
Habitat Lab
Commits
69cbd3bd
Commit
69cbd3bd
authored
5 years ago
by
danielgordon10
Committed by
Oleksandr
5 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Merging get_uneven_splits into get_splits function (#74)
* merged uneven_splits stuff into split function
parent
57691b6a
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
habitat/core/dataset.py
+31
-44
31 additions, 44 deletions
habitat/core/dataset.py
test/test_dataset.py
+22
-5
22 additions, 5 deletions
test/test_dataset.py
with
53 additions
and
49 deletions
habitat/core/dataset.py
+
31
−
44
View file @
69cbd3bd
...
...
@@ -128,10 +128,11 @@ class Dataset(Generic[T]):
def
get_splits
(
self
,
num_splits
:
int
,
max_
episodes_per_split
:
Optional
[
int
]
=
None
,
episodes_per_split
:
Optional
[
int
]
=
None
,
remove_unused_episodes
:
bool
=
False
,
collate_scene_ids
:
bool
=
True
,
sort_by_episode_id
:
bool
=
False
,
allow_uneven_splits
:
bool
=
False
,
)
->
List
[
"
Dataset
"
]:
"""
Returns a list of new datasets, each with a subset of the original
...
...
@@ -139,7 +140,7 @@ class Dataset(Generic[T]):
episodes will be duplicated.
Args:
num_splits: The number of splits to create.
max_
episodes_per_split: If provided, each split will have up to
episodes_per_split: If provided, each split will have up to
this many episodes. If it is not provided, each dataset will
have len(original_dataset.episodes) // num_splits episodes. If
episodes_per_split is provided and is larger than this
...
...
@@ -153,24 +154,42 @@ class Dataset(Generic[T]):
to each other because they will be in the same scene.
sort_by_episode_id: If true, sequences are sorted by their episode
ID in the returned splits.
allow_uneven_splits: If true, the last split can be shorter than
the others. This is especially useful for splitting over
validation/test datasets in order to make sure that all
episodes are copied but none are duplicated.
Returns:
A list of new datasets, each with their own subset of episodes.
"""
assert
(
len
(
self
.
episodes
)
>=
num_splits
),
"
Not enough episodes to create this many splits.
"
if
episodes_per_split
is
not
None
:
assert
not
allow_uneven_splits
,
(
"
You probably don
'
t want to specify allow_uneven_splits
"
"
and episodes_per_split.
"
)
assert
num_splits
*
episodes_per_split
<=
len
(
self
.
episodes
)
new_datasets
=
[]
if
max_episodes_per_split
is
None
:
max_episodes_per_split
=
len
(
self
.
episodes
)
//
num_splits
max_episodes_per_split
=
min
(
max_episodes_per_split
,
(
len
(
self
.
episodes
)
//
num_splits
)
)
if
allow_uneven_splits
:
stride
=
int
(
np
.
ceil
(
len
(
self
.
episodes
)
*
1.0
/
num_splits
))
split_lengths
=
[
stride
]
*
(
num_splits
-
1
)
split_lengths
.
append
(
(
len
(
self
.
episodes
)
-
stride
*
(
num_splits
-
1
))
)
else
:
if
episodes_per_split
is
not
None
:
stride
=
episodes_per_split
else
:
stride
=
len
(
self
.
episodes
)
//
num_splits
split_lengths
=
[
stride
]
*
num_splits
num_episodes
=
sum
(
split_lengths
)
rand_items
=
np
.
random
.
choice
(
len
(
self
.
episodes
),
num_splits
*
max_episodes_per_split
,
replace
=
False
,
len
(
self
.
episodes
),
num_episodes
,
replace
=
False
)
if
collate_scene_ids
:
scene_ids
=
{}
...
...
@@ -187,7 +206,7 @@ class Dataset(Generic[T]):
new_dataset
=
copy
.
copy
(
self
)
# Creates a shallow copy
new_dataset
.
episodes
=
[]
new_datasets
.
append
(
new_dataset
)
for
ii
in
range
(
max_episodes_per_split
):
for
ii
in
range
(
split_lengths
[
nn
]
):
new_dataset
.
episodes
.
append
(
self
.
episodes
[
rand_items
[
ep_ind
]])
ep_ind
+=
1
if
sort_by_episode_id
:
...
...
@@ -196,35 +215,3 @@ class Dataset(Generic[T]):
if
remove_unused_episodes
:
self
.
episodes
=
new_episodes
return
new_datasets
def
get_uneven_splits
(
self
,
num_splits
):
"""
Returns a list of new datasets, each with a subset of the original
episodes. The last dataset may have fewer episodes than the others.
This is especially useful for splitting over validation/test datasets
in order to make sure that all episodes are copied but none are
duplicated.
Args:
num_splits: The number of splits to create.
Returns:
A list of new datasets, each with their own subset of episodes.
"""
assert
(
len
(
self
.
episodes
)
>=
num_splits
),
"
Not enough episodes to create this many splits.
"
new_datasets
=
[]
num_episodes
=
len
(
self
.
episodes
)
stride
=
int
(
np
.
ceil
(
num_episodes
*
1.0
/
num_splits
))
for
ii
,
split
in
enumerate
(
range
(
0
,
num_episodes
,
stride
)[:
num_splits
]
):
new_dataset
=
copy
.
copy
(
self
)
# Creates a shallow copy
new_dataset
.
episodes
=
new_dataset
.
episodes
[
split
:
min
(
split
+
stride
,
num_episodes
)
].
copy
()
new_datasets
.
append
(
new_dataset
)
assert
(
sum
([
len
(
new_dataset
.
episodes
)
for
new_dataset
in
new_datasets
])
==
num_episodes
)
return
new_datasets
This diff is collapsed.
Click to expand it.
test/test_dataset.py
+
22
−
5
View file @
69cbd3bd
...
...
@@ -64,7 +64,7 @@ def test_get_splits_with_remainder():
assert
len
(
split
.
episodes
)
==
9
def
test_get_splits_
max
_episodes_specified
():
def
test_get_splits_
num
_episodes_specified
():
dataset
=
_construct_dataset
(
100
)
splits
=
dataset
.
get_splits
(
10
,
3
,
False
)
assert
len
(
splits
)
==
10
...
...
@@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified():
assert
len
(
dataset
.
episodes
)
==
100
dataset
=
_construct_dataset
(
100
)
splits
=
dataset
.
get_splits
(
10
,
1
1
,
False
)
splits
=
dataset
.
get_splits
(
10
,
1
0
)
assert
len
(
splits
)
==
10
for
split
in
splits
:
assert
len
(
split
.
episodes
)
==
10
...
...
@@ -86,6 +86,13 @@ def test_get_splits_max_episodes_specified():
assert
len
(
split
.
episodes
)
==
3
assert
len
(
dataset
.
episodes
)
==
30
dataset
=
_construct_dataset
(
100
)
try
:
splits
=
dataset
.
get_splits
(
10
,
20
)
assert
False
except
AssertionError
:
pass
def
test_get_splits_collate_scenes
():
dataset
=
_construct_dataset
(
10000
)
...
...
@@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id():
def
test_get_uneven_splits
():
dataset
=
_construct_dataset
(
100
)
splits
=
dataset
.
get_uneven_splits
(
9
)
dataset
=
_construct_dataset
(
10000
)
splits
=
dataset
.
get_splits
(
9
,
allow_uneven_splits
=
False
)
assert
len
(
splits
)
==
9
assert
sum
([
len
(
split
.
episodes
)
for
split
in
splits
])
==
(
10000
//
9
)
*
9
dataset
=
_construct_dataset
(
10000
)
splits
=
dataset
.
get_splits
(
9
,
allow_uneven_splits
=
True
)
assert
len
(
splits
)
==
9
assert
sum
([
len
(
split
.
episodes
)
for
split
in
splits
])
==
100
assert
sum
([
len
(
split
.
episodes
)
for
split
in
splits
])
==
10000
dataset
=
_construct_dataset
(
10000
)
splits
=
dataset
.
get_splits
(
10
,
allow_uneven_splits
=
True
)
assert
len
(
splits
)
==
10
assert
sum
([
len
(
split
.
episodes
)
for
split
in
splits
])
==
10000
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment