Skip to content

Commit 2ccc745

Browse files
SkafteNickiNicki SkafteBorda
authored
Error on zero length dataloaders (#1280)
* error_on_zero_length * update CHANGELOG.md * added test * Update pytorch_lightning/trainer/data_loading.py Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 09167ef commit 2ccc745

File tree

5 files changed

+43
-3
lines changed

5 files changed

+43
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
1919
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
2020
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
21+
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
2122

2223
### Changed
2324

pytorch_lightning/trainer/data_loading.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@
2626

2727

2828
def _has_len(dataloader: DataLoader) -> bool:
29+
""" Checks if a given Dataloader has __len__ method implemented i.e. if
30+
it is a finite dataloader or infinite dataloader """
2931
try:
3032
# try getting the length
31-
_ = len(dataloader)
33+
if len(dataloader) == 0:
34+
raise ValueError('Dataloader returned 0 length. Please make sure'
35+
' that your Dataloader atleast returns 1 batch')
3236
return True
3337
except TypeError:
3438
return False

tests/base/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
LightTestOptimizerWithSchedulingMixin,
2626
LightTestMultipleOptimizersWithSchedulingMixin,
2727
LightTestOptimizersWithMixedSchedulingMixin,
28-
LightTestReduceLROnPlateauMixin
28+
LightTestReduceLROnPlateauMixin,
29+
LightZeroLenDataloader
2930
)
3031

3132

tests/base/mixins.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,16 @@ def test_dataloader(self):
252252
return CustomInfDataloader(self._dataloader(train=False))
253253

254254

255+
class LightZeroLenDataloader:
256+
""" Simple dataloader that has zero length. """
257+
258+
def train_dataloader(self):
259+
dataloader = self._dataloader(train=True)
260+
dataloader.dataset.data = dataloader.dataset.data[:0]
261+
dataloader.dataset.targets = dataloader.dataset.targets[:0]
262+
return dataloader
263+
264+
255265
class LightEmptyTestStep:
256266
"""Empty test step."""
257267

tests/trainer/test_dataloaders.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
LightTrainDataloader,
1717
LightInfTrainDataloader,
1818
LightInfValDataloader,
19-
LightInfTestDataloader
19+
LightInfTestDataloader,
20+
LightZeroLenDataloader
2021
)
2122

2223

@@ -458,3 +459,26 @@ class CurrentTestModel(
458459

459460
# verify training completed
460461
assert result == 1
462+
463+
464+
def test_error_on_zero_len_dataloader(tmpdir):
465+
""" Test that error is raised if a zero-length dataloader is defined """
466+
tutils.reset_seed()
467+
468+
class CurrentTestModel(
469+
LightZeroLenDataloader,
470+
LightningTestModel
471+
):
472+
pass
473+
474+
hparams = tutils.get_default_hparams()
475+
model = CurrentTestModel(hparams)
476+
477+
# fit model
478+
with pytest.raises(ValueError):
479+
trainer = Trainer(
480+
default_save_path=tmpdir,
481+
max_epochs=1,
482+
test_percent_check=0.5
483+
)
484+
trainer.fit(model)

0 commit comments

Comments
 (0)