Skip to content

Commit 436fc53

Browse files
carmocca and kaushikb11
authored
Improve LightningDataModule hook test and fix dataloader_idx argument (#7941)
Co-authored-by: Kaushik B <[email protected]>
1 parent 6b7b404 commit 436fc53

File tree

4 files changed

+96
-88
lines changed

4 files changed

+96
-88
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
229229
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
230230

231231

232+
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
233+
234+
232235
## [1.3.5] - 2021-06-08
233236

234237
### Added

pytorch_lightning/trainer/predict_loop.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int:
9898

9999
def _build_kwargs(self, batch, batch_idx, dataloader_idx):
100100
step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
101-
if self.num_dataloaders:
101+
if self.num_dataloaders > 1:
102102
step_kwargs['dataloader_idx'] = dataloader_idx
103103
return step_kwargs
104104

tests/helpers/boring_model.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -161,20 +161,24 @@ def __init__(self, data_dir: str = './'):
161161
self.checkpoint_state: Optional[str] = None
162162

163163
def prepare_data(self):
164-
self.random_full = RandomDataset(32, 192)
164+
self.random_full = RandomDataset(32, 64 * 4)
165165

166166
def setup(self, stage: Optional[str] = None):
167167
if stage == "fit" or stage is None:
168168
self.random_train = Subset(self.random_full, indices=range(64))
169169
self.dims = self.random_train[0].shape
170170

171171
if stage in ("fit", "validate") or stage is None:
172-
self.random_val = Subset(self.random_full, indices=range(64, 128))
172+
self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
173173

174174
if stage == "test" or stage is None:
175-
self.random_test = Subset(self.random_full, indices=range(128, 192))
175+
self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
176176
self.dims = getattr(self, "dims", self.random_test[0].shape)
177177

178+
if stage == "predict" or stage is None:
179+
self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
180+
self.dims = getattr(self, "dims", self.random_predict[0].shape)
181+
178182
def train_dataloader(self):
179183
return DataLoader(self.random_train)
180184

@@ -183,3 +187,6 @@ def val_dataloader(self):
183187

184188
def test_dataloader(self):
185189
return DataLoader(self.random_test)
190+
191+
def predict_dataloader(self):
192+
return DataLoader(self.random_predict)

tests/models/test_hooks.py

Lines changed: 82 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from functools import partial
15+
from inspect import getmembers, isfunction
1416
from unittest import mock
15-
from unittest.mock import PropertyMock
17+
from unittest.mock import ANY, PropertyMock
1618

1719
import pytest
1820
import torch
1921
from torch.utils.data import DataLoader
2022

21-
from pytorch_lightning import Trainer
23+
from pytorch_lightning import __version__, LightningDataModule, Trainer
2224
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
2325
from tests.helpers.runif import RunIf
2426

@@ -666,107 +668,103 @@ def test_trainer_datamodule_hook_system(tmpdir):
666668

667669
class HookedDataModule(BoringDataModule):
668670

669-
def __init__(self):
671+
def __init__(self, called):
670672
super().__init__()
671-
self.called = []
672-
673-
def prepare_data(self):
674-
self.called.append("prepare_data")
675-
super().prepare_data()
676-
677-
def setup(self, stage=None):
678-
self.called.append(f"setup_{stage}")
679-
super().setup(stage=stage)
680-
681-
def teardown(self, stage=None):
682-
self.called.append(f"teardown_{stage}")
683-
super().teardown(stage=stage)
684-
685-
def train_dataloader(self):
686-
self.called.append("train_dataloader")
687-
return super().train_dataloader()
688-
689-
def test_dataloader(self):
690-
self.called.append("test_dataloader")
691-
return super().test_dataloader()
692-
693-
def val_dataloader(self):
694-
self.called.append("val_dataloader")
695-
return super().val_dataloader()
696-
697-
def predict_dataloader(self):
698-
self.called.append("predict_dataloader")
699-
700-
def transfer_batch_to_device(self, *args, **kwargs):
701-
self.called.append("transfer_batch_to_device")
702-
return super().transfer_batch_to_device(*args, **kwargs)
703-
704-
def on_before_batch_transfer(self, *args, **kwargs):
705-
self.called.append("on_before_batch_transfer")
706-
return super().on_before_batch_transfer(*args, **kwargs)
707673

708-
def on_after_batch_transfer(self, *args, **kwargs):
709-
self.called.append("on_after_batch_transfer")
710-
return super().on_after_batch_transfer(*args, **kwargs)
674+
def call(hook, fn, *args, **kwargs):
675+
out = fn(*args, **kwargs)
676+
d = {'name': hook}
677+
if args:
678+
d['args'] = args
679+
if kwargs:
680+
d['kwargs'] = kwargs
681+
called.append(d)
682+
return out
683+
684+
hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)}
685+
for h in hooks:
686+
attr = getattr(self, h)
687+
setattr(self, h, partial(call, h, attr))
711688

712689
model = BoringModel()
713-
dm = HookedDataModule()
714-
690+
batches = 2
715691
trainer = Trainer(
716692
default_root_dir=tmpdir,
717693
max_epochs=1,
718-
limit_val_batches=1,
719-
limit_train_batches=2,
720-
limit_test_batches=1,
694+
limit_train_batches=batches,
695+
limit_val_batches=batches,
696+
limit_test_batches=batches,
697+
limit_predict_batches=batches,
721698
progress_bar_refresh_rate=0,
722699
weights_summary=None,
723700
reload_dataloaders_every_epoch=True,
724701
)
702+
703+
called = []
704+
dm = HookedDataModule(called)
725705
trainer.fit(model, datamodule=dm)
706+
batch_transfer = [
707+
dict(name='on_before_batch_transfer', args=(ANY, None)),
708+
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
709+
dict(name='on_after_batch_transfer', args=(ANY, None)),
710+
]
726711
expected = [
727-
'prepare_data',
728-
'setup_fit',
729-
'val_dataloader',
730-
'on_before_batch_transfer',
731-
'transfer_batch_to_device',
732-
'on_after_batch_transfer',
733-
'train_dataloader',
734-
'on_before_batch_transfer',
735-
'transfer_batch_to_device',
736-
'on_after_batch_transfer',
737-
'on_before_batch_transfer',
738-
'transfer_batch_to_device',
739-
'on_after_batch_transfer',
740-
'val_dataloader',
741-
'on_before_batch_transfer',
742-
'transfer_batch_to_device',
743-
'on_after_batch_transfer',
744-
'teardown_fit',
712+
dict(name='prepare_data'),
713+
dict(name='setup', kwargs=dict(stage='fit')),
714+
dict(name='val_dataloader'),
715+
*batch_transfer * batches,
716+
dict(name='train_dataloader'),
717+
*batch_transfer * batches,
718+
dict(name='val_dataloader'),
719+
*batch_transfer * batches,
720+
dict(
721+
name='on_save_checkpoint',
722+
args=({
723+
'callbacks': ANY,
724+
'epoch': 1,
725+
'global_step': 2,
726+
'lr_schedulers': ANY,
727+
'optimizer_states': ANY,
728+
'pytorch-lightning_version': __version__,
729+
'state_dict': ANY
730+
}, )
731+
),
732+
dict(name='teardown', kwargs=dict(stage='fit')),
745733
]
746-
assert dm.called == expected
734+
assert called == expected
747735

748-
dm = HookedDataModule()
736+
called = []
737+
dm = HookedDataModule(called)
749738
trainer.validate(model, datamodule=dm, verbose=False)
750739
expected = [
751-
'prepare_data',
752-
'setup_validate',
753-
'val_dataloader',
754-
'on_before_batch_transfer',
755-
'transfer_batch_to_device',
756-
'on_after_batch_transfer',
757-
'teardown_validate',
740+
dict(name='prepare_data'),
741+
dict(name='setup', kwargs=dict(stage='validate')),
742+
dict(name='val_dataloader'),
743+
*batch_transfer * batches,
744+
dict(name='teardown', kwargs=dict(stage='validate')),
758745
]
759-
assert dm.called == expected
746+
assert called == expected
760747

761-
dm = HookedDataModule()
748+
called = []
749+
dm = HookedDataModule(called)
762750
trainer.test(model, datamodule=dm, verbose=False)
763751
expected = [
764-
'prepare_data',
765-
'setup_test',
766-
'test_dataloader',
767-
'on_before_batch_transfer',
768-
'transfer_batch_to_device',
769-
'on_after_batch_transfer',
770-
'teardown_test',
752+
dict(name='prepare_data'),
753+
dict(name='setup', kwargs=dict(stage='test')),
754+
dict(name='test_dataloader'),
755+
*batch_transfer * batches,
756+
dict(name='teardown', kwargs=dict(stage='test')),
757+
]
758+
assert called == expected
759+
760+
called = []
761+
dm = HookedDataModule(called)
762+
trainer.predict(model, datamodule=dm)
763+
expected = [
764+
dict(name='prepare_data'),
765+
dict(name='setup', kwargs=dict(stage='predict')),
766+
dict(name='predict_dataloader'),
767+
*batch_transfer * batches,
768+
dict(name='teardown', kwargs=dict(stage='predict')),
771769
]
772-
assert dm.called == expected
770+
assert called == expected

0 commit comments

Comments
 (0)