diff --git a/CHANGELOG.md b/CHANGELOG.md
index 784a1581ee97a..8f0bf3eb0e09c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
 
+- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
+
+
 ## [1.3.5] - 2021-06-08
 
 ### Added
diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py
index c06ced6662d81..f5487f335a035 100644
--- a/pytorch_lightning/trainer/predict_loop.py
+++ b/pytorch_lightning/trainer/predict_loop.py
@@ -98,7 +98,7 @@ def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int:
 
     def _build_kwargs(self, batch, batch_idx, dataloader_idx):
         step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
-        if self.num_dataloaders:
+        if self.num_dataloaders > 1:
             step_kwargs['dataloader_idx'] = dataloader_idx
         return step_kwargs
 
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
index fff9d2de79f77..185baac51f41f 100644
--- a/tests/helpers/boring_model.py
+++ b/tests/helpers/boring_model.py
@@ -161,7 +161,7 @@ def __init__(self, data_dir: str = './'):
         self.checkpoint_state: Optional[str] = None
 
     def prepare_data(self):
-        self.random_full = RandomDataset(32, 192)
+        self.random_full = RandomDataset(32, 64 * 4)
 
     def setup(self, stage: Optional[str] = None):
         if stage == "fit" or stage is None:
@@ -169,12 +169,16 @@ def setup(self, stage: Optional[str] = None):
             self.dims = self.random_train[0].shape
 
         if stage in ("fit", "validate") or stage is None:
-            self.random_val = Subset(self.random_full, indices=range(64, 128))
+            self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
 
         if stage == "test" or stage is None:
-            self.random_test = Subset(self.random_full, indices=range(128, 192))
+            self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
             self.dims = getattr(self, "dims", self.random_test[0].shape)
 
+        if stage == "predict" or stage is None:
+            self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
+            self.dims = getattr(self, "dims", self.random_predict[0].shape)
+
     def train_dataloader(self):
         return DataLoader(self.random_train)
 
@@ -183,3 +187,6 @@ def val_dataloader(self):
 
     def test_dataloader(self):
         return DataLoader(self.random_test)
+
+    def predict_dataloader(self):
+        return DataLoader(self.random_predict)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 60354c987fab3..9904e43be69b4 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -11,14 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
+from inspect import getmembers, isfunction
 from unittest import mock
-from unittest.mock import PropertyMock
+from unittest.mock import ANY, PropertyMock
 
 import pytest
 import torch
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import __version__, LightningDataModule, Trainer
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -666,107 +668,103 @@ def test_trainer_datamodule_hook_system(tmpdir):
 
     class HookedDataModule(BoringDataModule):
 
-        def __init__(self):
+        def __init__(self, called):
             super().__init__()
-            self.called = []
 
-        def prepare_data(self):
-            self.called.append("prepare_data")
-            super().prepare_data()
-
-        def setup(self, stage=None):
-            self.called.append(f"setup_{stage}")
-            super().setup(stage=stage)
-
-        def teardown(self, stage=None):
-            self.called.append(f"teardown_{stage}")
-            super().teardown(stage=stage)
-
-        def train_dataloader(self):
-            self.called.append("train_dataloader")
-            return super().train_dataloader()
-
-        def test_dataloader(self):
-            self.called.append("test_dataloader")
-            return super().test_dataloader()
-
-        def val_dataloader(self):
-            self.called.append("val_dataloader")
-            return super().val_dataloader()
-
-        def predict_dataloader(self):
-            self.called.append("predict_dataloader")
-
-        def transfer_batch_to_device(self, *args, **kwargs):
-            self.called.append("transfer_batch_to_device")
-            return super().transfer_batch_to_device(*args, **kwargs)
-
-        def on_before_batch_transfer(self, *args, **kwargs):
-            self.called.append("on_before_batch_transfer")
-            return super().on_before_batch_transfer(*args, **kwargs)
 
-        def on_after_batch_transfer(self, *args, **kwargs):
-            self.called.append("on_after_batch_transfer")
-            return super().on_after_batch_transfer(*args, **kwargs)
+            def call(hook, fn, *args, **kwargs):
+                out = fn(*args, **kwargs)
+                d = {'name': hook}
+                if args:
+                    d['args'] = args
+                if kwargs:
+                    d['kwargs'] = kwargs
+                called.append(d)
+                return out
+
+            hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)}
+            for h in hooks:
+                attr = getattr(self, h)
+                setattr(self, h, partial(call, h, attr))
 
     model = BoringModel()
-    dm = HookedDataModule()
-
+    batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        limit_val_batches=1,
-        limit_train_batches=2,
-        limit_test_batches=1,
+        limit_train_batches=batches,
+        limit_val_batches=batches,
+        limit_test_batches=batches,
+        limit_predict_batches=batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
         reload_dataloaders_every_epoch=True,
     )
+
+    called = []
+    dm = HookedDataModule(called)
     trainer.fit(model, datamodule=dm)
+    batch_transfer = [
+        dict(name='on_before_batch_transfer', args=(ANY, None)),
+        dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
+        dict(name='on_after_batch_transfer', args=(ANY, None)),
+    ]
     expected = [
-        'prepare_data',
-        'setup_fit',
-        'val_dataloader',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'train_dataloader',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'val_dataloader',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'teardown_fit',
+        dict(name='prepare_data'),
+        dict(name='setup', kwargs=dict(stage='fit')),
+        dict(name='val_dataloader'),
+        *batch_transfer * batches,
+        dict(name='train_dataloader'),
+        *batch_transfer * batches,
+        dict(name='val_dataloader'),
+        *batch_transfer * batches,
+        dict(
+            name='on_save_checkpoint',
+            args=({
+                'callbacks': ANY,
+                'epoch': 1,
+                'global_step': 2,
+                'lr_schedulers': ANY,
+                'optimizer_states': ANY,
+                'pytorch-lightning_version': __version__,
+                'state_dict': ANY
+            }, )
+        ),
+        dict(name='teardown', kwargs=dict(stage='fit')),
     ]
-    assert dm.called == expected
+    assert called == expected
 
-    dm = HookedDataModule()
+    called = []
+    dm = HookedDataModule(called)
     trainer.validate(model, datamodule=dm, verbose=False)
     expected = [
-        'prepare_data',
-        'setup_validate',
-        'val_dataloader',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'teardown_validate',
+        dict(name='prepare_data'),
+        dict(name='setup', kwargs=dict(stage='validate')),
+        dict(name='val_dataloader'),
+        *batch_transfer * batches,
+        dict(name='teardown', kwargs=dict(stage='validate')),
     ]
-    assert dm.called == expected
+    assert called == expected
 
-    dm = HookedDataModule()
+    called = []
+    dm = HookedDataModule(called)
     trainer.test(model, datamodule=dm, verbose=False)
     expected = [
-        'prepare_data',
-        'setup_test',
-        'test_dataloader',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'teardown_test',
+        dict(name='prepare_data'),
+        dict(name='setup', kwargs=dict(stage='test')),
+        dict(name='test_dataloader'),
+        *batch_transfer * batches,
+        dict(name='teardown', kwargs=dict(stage='test')),
+    ]
+    assert called == expected
+
+    called = []
+    dm = HookedDataModule(called)
+    trainer.predict(model, datamodule=dm)
+    expected = [
+        dict(name='prepare_data'),
+        dict(name='setup', kwargs=dict(stage='predict')),
+        dict(name='predict_dataloader'),
+        *batch_transfer * batches,
+        dict(name='teardown', kwargs=dict(stage='predict')),
     ]
-    assert dm.called == expected
+    assert called == expected
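
The behavioral change in `predict_loop.py` is easiest to see in isolation. The sketch below is a hypothetical standalone `build_kwargs` helper mirroring the patched `_build_kwargs`, not Lightning's actual internals: with the old truthiness check `if self.num_dataloaders:`, a single dataloader (count of 1) still caused `dataloader_idx` to be forwarded to `predict_step`; with `> 1`, it is forwarded only when predicting over multiple dataloaders.

```python
from collections import OrderedDict


def build_kwargs(batch, batch_idx, dataloader_idx, num_dataloaders):
    # Mirrors the patched check: only inject `dataloader_idx` when
    # more than one predict dataloader is in play.
    step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
    if num_dataloaders > 1:
        step_kwargs['dataloader_idx'] = dataloader_idx
    return step_kwargs


# Single dataloader: the step receives only (batch, batch_idx).
assert 'dataloader_idx' not in build_kwargs('batch', 0, 0, num_dataloaders=1)
# Multiple dataloaders: the index is forwarded as before.
assert build_kwargs('batch', 0, 1, num_dataloaders=2)['dataloader_idx'] == 1
```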
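
The rewritten test replaces one hand-written override per hook with a generic recorder: `HookedDataModule.__init__` enumerates every function on `LightningDataModule` via `getmembers` and rebinds each one to a `partial` that logs the call (name, args, kwargs) before delegating. A minimal self-contained sketch of the same pattern, using a stand-in `Base` class instead of `LightningDataModule`:

```python
from functools import partial
from inspect import getmembers, isfunction


class Base:
    """Stand-in for LightningDataModule in this sketch."""

    def setup(self, stage=None):
        pass

    def teardown(self, stage=None):
        pass


class Hooked(Base):

    def __init__(self, called):
        super().__init__()

        # Record each hook call as a dict in the shared `called` list,
        # then delegate to the original bound method.
        def call(hook, fn, *args, **kwargs):
            out = fn(*args, **kwargs)
            d = {'name': hook}
            if args:
                d['args'] = args
            if kwargs:
                d['kwargs'] = kwargs
            called.append(d)
            return out

        for h, _ in getmembers(Base, predicate=isfunction):
            setattr(self, h, partial(call, h, getattr(self, h)))


called = []
dm = Hooked(called)
dm.setup(stage='fit')
dm.teardown(stage='fit')
assert called == [
    {'name': 'setup', 'kwargs': {'stage': 'fit'}},
    {'name': 'teardown', 'kwargs': {'stage': 'fit'}},
]
```

This is why the test can now assert on call order, arguments (via `mock.ANY`), and keyword arguments such as `stage='fit'` with a single `expected` list per trainer entry point, instead of maintaining a bespoke override for every hook.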