Skip to content

Commit 554fb47

Browse files
thschaafThomas Schaaf
andauthored
Bugfix/_has_len (#2293)
* deal with NotImplementedError raised by torchtext * deal with NotImplementedError raised by torchtext * Added tests for dataloader which raise NotImplementedError in __len__() * Fixed some typos Co-authored-by: Thomas Schaaf <[email protected]>
1 parent 3256fe4 commit 554fb47

File tree

6 files changed

+114
-0
lines changed

6 files changed

+114
-0
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def _has_len(dataloader: DataLoader) -> bool:
5252
return True
5353
except TypeError:
5454
return False
55+
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
56+
return False
5557

5658

5759
class TrainerDataLoadingMixin(ABC):

tests/base/dataloaders.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,29 @@ def __next__(self):
2121
except StopIteration:
2222
self.iter = iter(self.dataloader)
2323
return next(self.iter)
24+
25+
26+
class CustomNotImplementedErrorDataloader:
27+
28+
def __init__(self, dataloader):
29+
self.dataloader = dataloader
30+
self.iter = iter(dataloader)
31+
self.count = 0
32+
33+
def __len__(self):
34+
"""raise NotImplementedError"""
35+
raise NotImplementedError
36+
37+
def __iter__(self):
38+
self.count = 0
39+
return self
40+
41+
def __next__(self):
42+
if self.count >= 50:
43+
raise StopIteration
44+
self.count = self.count + 1
45+
try:
46+
return next(self.iter)
47+
except StopIteration:
48+
self.iter = iter(self.dataloader)
49+
return next(self.iter)

tests/base/model_test_dataloaders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from tests.base.dataloaders import CustomInfDataloader
4+
from tests.base.dataloaders import CustomNotImplementedErrorDataloader
45

56

67
class TestDataloaderVariations(ABC):
@@ -15,6 +16,9 @@ def test_dataloader(self):
1516
def test_dataloader__infinite(self):
1617
return CustomInfDataloader(self.dataloader(train=False))
1718

19+
def test_dataloader__not_implemented_error(self):
20+
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
21+
1822
def test_dataloader__empty(self):
1923
return None
2024

tests/base/model_train_dataloaders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from tests.base.dataloaders import CustomInfDataloader
4+
from tests.base.dataloaders import CustomNotImplementedErrorDataloader
45

56

67
class TrainDataloaderVariations(ABC):
@@ -15,6 +16,9 @@ def train_dataloader(self):
1516
def train_dataloader__infinite(self):
1617
return CustomInfDataloader(self.dataloader(train=True))
1718

19+
def train_dataloader__not_implemented_error(self):
20+
return CustomNotImplementedErrorDataloader(self.dataloader(train=True))
21+
1822
def train_dataloader__zero_length(self):
1923
dataloader = self.dataloader(train=True)
2024
dataloader.dataset.data = dataloader.dataset.data[:0]

tests/base/model_valid_dataloaders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from tests.base.dataloaders import CustomInfDataloader
4+
from tests.base.dataloaders import CustomNotImplementedErrorDataloader
45

56

67
class ValDataloaderVariations(ABC):
@@ -18,3 +19,6 @@ def val_dataloader__multiple(self):
1819

1920
def val_dataloader__infinite(self):
2021
return CustomInfDataloader(self.dataloader(train=False))
22+
23+
def val_dataloader__not_implemented_error(self):
24+
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

tests/trainer/test_dataloaders.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,18 @@ def test_train_inf_dataloader_error(tmpdir):
295295
trainer.fit(model)
296296

297297

298+
@pytest.mark.skip('TODO: speed up this test')
299+
def test_train_not_implemented_error_dataloader_error(tmpdir):
300+
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
301+
model = EvalModelTemplate()
302+
model.train_dataloader = model.train_dataloader__not_implemented_error
303+
304+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
305+
306+
with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
307+
trainer.fit(model)
308+
309+
298310
@pytest.mark.skip('TODO: speed up this test')
299311
def test_val_inf_dataloader_error(tmpdir):
300312
"""Test inf train data loader (e.g. IterableDataset)"""
@@ -307,6 +319,18 @@ def test_val_inf_dataloader_error(tmpdir):
307319
trainer.fit(model)
308320

309321

322+
@pytest.mark.skip('TODO: speed up this test')
323+
def test_val_not_implemented_error_dataloader_error(tmpdir):
324+
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
325+
model = EvalModelTemplate()
326+
model.val_dataloader = model.val_dataloader__not_implemented_error
327+
328+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
329+
330+
with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
331+
trainer.fit(model)
332+
333+
310334
@pytest.mark.skip('TODO: speed up this test')
311335
def test_test_inf_dataloader_error(tmpdir):
312336
"""Test inf train data loader (e.g. IterableDataset)"""
@@ -319,6 +343,18 @@ def test_test_inf_dataloader_error(tmpdir):
319343
trainer.test(model)
320344

321345

346+
@pytest.mark.skip('TODO: speed up this test')
347+
def test_test_not_implemented_error_dataloader_error(tmpdir):
348+
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
349+
model = EvalModelTemplate()
350+
model.test_dataloader = model.test_dataloader__not_implemented_error
351+
352+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
353+
354+
with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
355+
trainer.test(model)
356+
357+
322358
@pytest.mark.parametrize('check_interval', [50, 1.0])
323359
@pytest.mark.skip('TODO: speed up this test')
324360
def test_inf_train_dataloader(tmpdir, check_interval):
@@ -337,6 +373,24 @@ def test_inf_train_dataloader(tmpdir, check_interval):
337373
assert result == 1
338374

339375

376+
@pytest.mark.parametrize('check_interval', [50, 1.0])
377+
@pytest.mark.skip('TODO: speed up this test')
378+
def test_not_implemented_error_train_dataloader(tmpdir, check_interval):
379+
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
380+
381+
model = EvalModelTemplate()
382+
model.train_dataloader = model.train_dataloader__not_implemented_error
383+
384+
trainer = Trainer(
385+
default_root_dir=tmpdir,
386+
max_epochs=1,
387+
val_check_interval=check_interval
388+
)
389+
result = trainer.fit(model)
390+
# verify training completed
391+
assert result == 1
392+
393+
340394
@pytest.mark.parametrize('check_interval', [1.0])
341395
@pytest.mark.skip('TODO: speed up this test')
342396
def test_inf_val_dataloader(tmpdir, check_interval):
@@ -357,6 +411,26 @@ def test_inf_val_dataloader(tmpdir, check_interval):
357411
assert result == 1
358412

359413

414+
@pytest.mark.parametrize('check_interval', [1.0])
415+
@pytest.mark.skip('TODO: speed up this test')
416+
def test_not_implemented_error_dataloader(tmpdir, check_interval):
417+
"""Test not_implemented_error data loader (e.g. IterableDataset)"""
418+
419+
model = EvalModelTemplate()
420+
model.val_dataloader = model.val_dataloader__not_implemented_error
421+
422+
# logger file to get meta
423+
trainer = Trainer(
424+
default_root_dir=tmpdir,
425+
max_epochs=1,
426+
val_check_interval=check_interval,
427+
)
428+
result = trainer.fit(model)
429+
430+
# verify training completed
431+
assert result == 1
432+
433+
360434
def test_error_on_zero_len_dataloader(tmpdir):
361435
""" Test that error is raised if a zero-length dataloader is defined """
362436

0 commit comments

Comments
 (0)