diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81846809fbf85..c917ee94f8c5f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -197,6 +197,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
+- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
+
+
 ## [1.2.6] - 2021-03-30
 
 ### Changed
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d565f0906e59e..27dcd6fe9aa0d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -866,7 +866,7 @@ def validate(
         self.validating = True
 
         # If you supply a datamodule you can't supply val_dataloaders
-        if val_dataloaders and datamodule:
+        if val_dataloaders is not None and datamodule:
             raise MisconfigurationException(
                 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
             )
@@ -928,7 +928,7 @@ def test(
         self.testing = True
 
         # If you supply a datamodule you can't supply test_dataloaders
-        if test_dataloaders and datamodule:
+        if test_dataloaders is not None and datamodule:
             raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')
 
         model_provided = model is not None
@@ -1024,7 +1024,7 @@ def predict(
         self.state = TrainerState.PREDICTING
         self.predicting = True
 
-        if dataloaders and datamodule:
+        if dataloaders is not None and datamodule:
             raise MisconfigurationException(
                 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
             )
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 505af173b7910..7f9cf6210ce7c 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -636,28 +636,42 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
 
 def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
-    model = EvalModelTemplate()
+    model = BoringModel()
     original_dataset = model.train_dataloader().dataset
 
-    class IterableWithLen(IterableDataset):
+    class IterableWithoutLen(IterableDataset):
 
         def __iter__(self):
             return iter(original_dataset)
 
+    class IterableWithLen(IterableWithoutLen):
+
         def __len__(self):
             return len(original_dataset)
 
+    # with __len__ defined
     dataloader = DataLoader(IterableWithLen(), batch_size=16)
     assert has_len(dataloader)
     assert has_iterable_dataset(dataloader)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=3,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.validate(model, val_dataloaders=[dataloader])
     with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
         trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
     with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
         trainer.test(model, test_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.predict(model, dataloaders=[dataloader])
+
+    # without __len__ defined
+    dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
+    assert not has_len(dataloader)
+    assert has_iterable_dataset(dataloader)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
+    trainer.validate(model, val_dataloaders=dataloader)
+    trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    trainer.test(model, test_dataloaders=dataloader)
+    trainer.predict(model, dataloaders=dataloader)
 
 
 @RunIf(min_gpus=2)