|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +from functools import partial |
| 15 | +from inspect import getmembers, isfunction |
14 | 16 | from unittest import mock |
15 | | -from unittest.mock import PropertyMock |
| 17 | +from unittest.mock import ANY, PropertyMock |
16 | 18 |
|
17 | 19 | import pytest |
18 | 20 | import torch |
19 | 21 | from torch.utils.data import DataLoader |
20 | 22 |
|
21 | | -from pytorch_lightning import Trainer |
| 23 | +from pytorch_lightning import __version__, LightningDataModule, Trainer |
22 | 24 | from tests.helpers import BoringDataModule, BoringModel, RandomDataset |
23 | 25 | from tests.helpers.runif import RunIf |
24 | 26 |
|
@@ -666,107 +668,103 @@ def test_trainer_datamodule_hook_system(tmpdir): |
666 | 668 |
|
667 | 669 | class HookedDataModule(BoringDataModule): |
668 | 670 |
|
669 | | - def __init__(self): |
| 671 | + def __init__(self, called): |
670 | 672 | super().__init__() |
671 | | - self.called = [] |
672 | | - |
673 | | - def prepare_data(self): |
674 | | - self.called.append("prepare_data") |
675 | | - super().prepare_data() |
676 | | - |
677 | | - def setup(self, stage=None): |
678 | | - self.called.append(f"setup_{stage}") |
679 | | - super().setup(stage=stage) |
680 | | - |
681 | | - def teardown(self, stage=None): |
682 | | - self.called.append(f"teardown_{stage}") |
683 | | - super().teardown(stage=stage) |
684 | | - |
685 | | - def train_dataloader(self): |
686 | | - self.called.append("train_dataloader") |
687 | | - return super().train_dataloader() |
688 | | - |
689 | | - def test_dataloader(self): |
690 | | - self.called.append("test_dataloader") |
691 | | - return super().test_dataloader() |
692 | | - |
693 | | - def val_dataloader(self): |
694 | | - self.called.append("val_dataloader") |
695 | | - return super().val_dataloader() |
696 | | - |
697 | | - def predict_dataloader(self): |
698 | | - self.called.append("predict_dataloader") |
699 | | - |
700 | | - def transfer_batch_to_device(self, *args, **kwargs): |
701 | | - self.called.append("transfer_batch_to_device") |
702 | | - return super().transfer_batch_to_device(*args, **kwargs) |
703 | | - |
704 | | - def on_before_batch_transfer(self, *args, **kwargs): |
705 | | - self.called.append("on_before_batch_transfer") |
706 | | - return super().on_before_batch_transfer(*args, **kwargs) |
707 | 673 |
|
708 | | - def on_after_batch_transfer(self, *args, **kwargs): |
709 | | - self.called.append("on_after_batch_transfer") |
710 | | - return super().on_after_batch_transfer(*args, **kwargs) |
| 674 | + def call(hook, fn, *args, **kwargs): |
| 675 | + out = fn(*args, **kwargs) |
| 676 | + d = {'name': hook} |
| 677 | + if args: |
| 678 | + d['args'] = args |
| 679 | + if kwargs: |
| 680 | + d['kwargs'] = kwargs |
| 681 | + called.append(d) |
| 682 | + return out |
| 683 | + |
| 684 | + hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)} |
| 685 | + for h in hooks: |
| 686 | + attr = getattr(self, h) |
| 687 | + setattr(self, h, partial(call, h, attr)) |
711 | 688 |
|
712 | 689 | model = BoringModel() |
713 | | - dm = HookedDataModule() |
714 | | - |
| 690 | + batches = 2 |
715 | 691 | trainer = Trainer( |
716 | 692 | default_root_dir=tmpdir, |
717 | 693 | max_epochs=1, |
718 | | - limit_val_batches=1, |
719 | | - limit_train_batches=2, |
720 | | - limit_test_batches=1, |
| 694 | + limit_train_batches=batches, |
| 695 | + limit_val_batches=batches, |
| 696 | + limit_test_batches=batches, |
| 697 | + limit_predict_batches=batches, |
721 | 698 | progress_bar_refresh_rate=0, |
722 | 699 | weights_summary=None, |
723 | 700 | reload_dataloaders_every_epoch=True, |
724 | 701 | ) |
| 702 | + |
| 703 | + called = [] |
| 704 | + dm = HookedDataModule(called) |
725 | 705 | trainer.fit(model, datamodule=dm) |
| 706 | + batch_transfer = [ |
| 707 | + dict(name='on_before_batch_transfer', args=(ANY, None)), |
| 708 | + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), |
| 709 | + dict(name='on_after_batch_transfer', args=(ANY, None)), |
| 710 | + ] |
726 | 711 | expected = [ |
727 | | - 'prepare_data', |
728 | | - 'setup_fit', |
729 | | - 'val_dataloader', |
730 | | - 'on_before_batch_transfer', |
731 | | - 'transfer_batch_to_device', |
732 | | - 'on_after_batch_transfer', |
733 | | - 'train_dataloader', |
734 | | - 'on_before_batch_transfer', |
735 | | - 'transfer_batch_to_device', |
736 | | - 'on_after_batch_transfer', |
737 | | - 'on_before_batch_transfer', |
738 | | - 'transfer_batch_to_device', |
739 | | - 'on_after_batch_transfer', |
740 | | - 'val_dataloader', |
741 | | - 'on_before_batch_transfer', |
742 | | - 'transfer_batch_to_device', |
743 | | - 'on_after_batch_transfer', |
744 | | - 'teardown_fit', |
| 712 | + dict(name='prepare_data'), |
| 713 | + dict(name='setup', kwargs=dict(stage='fit')), |
| 714 | + dict(name='val_dataloader'), |
| 715 | + *batch_transfer * batches, |
| 716 | + dict(name='train_dataloader'), |
| 717 | + *batch_transfer * batches, |
| 718 | + dict(name='val_dataloader'), |
| 719 | + *batch_transfer * batches, |
| 720 | + dict( |
| 721 | + name='on_save_checkpoint', |
| 722 | + args=({ |
| 723 | + 'callbacks': ANY, |
| 724 | + 'epoch': 1, |
| 725 | + 'global_step': 2, |
| 726 | + 'lr_schedulers': ANY, |
| 727 | + 'optimizer_states': ANY, |
| 728 | + 'pytorch-lightning_version': __version__, |
| 729 | + 'state_dict': ANY |
| 730 | + }, ) |
| 731 | + ), |
| 732 | + dict(name='teardown', kwargs=dict(stage='fit')), |
745 | 733 | ] |
746 | | - assert dm.called == expected |
| 734 | + assert called == expected |
747 | 735 |
|
748 | | - dm = HookedDataModule() |
| 736 | + called = [] |
| 737 | + dm = HookedDataModule(called) |
749 | 738 | trainer.validate(model, datamodule=dm, verbose=False) |
750 | 739 | expected = [ |
751 | | - 'prepare_data', |
752 | | - 'setup_validate', |
753 | | - 'val_dataloader', |
754 | | - 'on_before_batch_transfer', |
755 | | - 'transfer_batch_to_device', |
756 | | - 'on_after_batch_transfer', |
757 | | - 'teardown_validate', |
| 740 | + dict(name='prepare_data'), |
| 741 | + dict(name='setup', kwargs=dict(stage='validate')), |
| 742 | + dict(name='val_dataloader'), |
| 743 | + *batch_transfer * batches, |
| 744 | + dict(name='teardown', kwargs=dict(stage='validate')), |
758 | 745 | ] |
759 | | - assert dm.called == expected |
| 746 | + assert called == expected |
760 | 747 |
|
761 | | - dm = HookedDataModule() |
| 748 | + called = [] |
| 749 | + dm = HookedDataModule(called) |
762 | 750 | trainer.test(model, datamodule=dm, verbose=False) |
763 | 751 | expected = [ |
764 | | - 'prepare_data', |
765 | | - 'setup_test', |
766 | | - 'test_dataloader', |
767 | | - 'on_before_batch_transfer', |
768 | | - 'transfer_batch_to_device', |
769 | | - 'on_after_batch_transfer', |
770 | | - 'teardown_test', |
| 752 | + dict(name='prepare_data'), |
| 753 | + dict(name='setup', kwargs=dict(stage='test')), |
| 754 | + dict(name='test_dataloader'), |
| 755 | + *batch_transfer * batches, |
| 756 | + dict(name='teardown', kwargs=dict(stage='test')), |
| 757 | + ] |
| 758 | + assert called == expected |
| 759 | + |
| 760 | + called = [] |
| 761 | + dm = HookedDataModule(called) |
| 762 | + trainer.predict(model, datamodule=dm) |
| 763 | + expected = [ |
| 764 | + dict(name='prepare_data'), |
| 765 | + dict(name='setup', kwargs=dict(stage='predict')), |
| 766 | + dict(name='predict_dataloader'), |
| 767 | + *batch_transfer * batches, |
| 768 | + dict(name='teardown', kwargs=dict(stage='predict')), |
771 | 769 | ] |
772 | | - assert dm.called == expected |
| 770 | + assert called == expected |
0 commit comments