|
17 | 17 | from torch.utils.data import DataLoader |
18 | 18 |
|
19 | 19 | from pytorch_lightning import Trainer |
20 | | -from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource |
| 20 | +from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache |
21 | 21 | from pytorch_lightning.trainer.states import TrainerFn |
22 | 22 | from pytorch_lightning.utilities.warnings import PossibleUserWarning |
| 23 | +from tests.deprecated_api import no_warning_call |
23 | 24 | from tests.helpers import BoringDataModule, BoringModel |
24 | 25 | from tests.helpers.boring_model import RandomDataset |
25 | 26 |
|
@@ -71,6 +72,77 @@ def test_dataloader_source_request_from_module(): |
71 | 72 | module.foo.assert_called_once() |
72 | 73 |
|
73 | 74 |
|
| 75 | +@pytest.mark.parametrize( |
| 76 | + "hook_name", ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") |
| 77 | +) |
| 78 | +class TestDataHookSelector: |
| 79 | + def overridden_func(self, batch, *args, **kwargs): |
| 80 | + return batch |
| 81 | + |
| 82 | + def reset_instances(self): |
| 83 | + return BoringDataModule(), BoringModel(), Trainer() |
| 84 | + |
| 85 | + def test_no_datamodule_no_overridden(self, hook_name): |
| 86 | + model, _, trainer = self.reset_instances() |
| 87 | + trainer._data_connector.attach_datamodule(model, datamodule=None) |
| 88 | + with no_warning_call(match="have overridden `{hook_name}` in both"): |
| 89 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 90 | + |
| 91 | + assert hook == getattr(model, hook_name) |
| 92 | + |
| 93 | + def test_with_datamodule_no_overridden(self, hook_name): |
| 94 | + model, dm, trainer = self.reset_instances() |
| 95 | + trainer._data_connector.attach_datamodule(model, datamodule=dm) |
| 96 | + with no_warning_call(match="have overridden `{hook_name}` in both"): |
| 97 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 98 | + |
| 99 | + assert hook == getattr(model, hook_name) |
| 100 | + |
| 101 | + def test_override_model_hook(self, hook_name): |
| 102 | + model, dm, trainer = self.reset_instances() |
| 103 | + trainer._data_connector.attach_datamodule(model, datamodule=dm) |
| 104 | + with no_warning_call(match="have overridden `{hook_name}` in both"): |
| 105 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 106 | + |
| 107 | + assert hook == getattr(model, hook_name) |
| 108 | + |
| 109 | + def test_override_datamodule_hook(self, hook_name): |
| 110 | + model, dm, trainer = self.reset_instances() |
| 111 | + trainer._data_connector.attach_datamodule(model, datamodule=dm) |
| 112 | + setattr(dm, hook_name, self.overridden_func) |
| 113 | + with no_warning_call(match="have overridden `{hook_name}` in both"): |
| 114 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 115 | + |
| 116 | + assert hook == getattr(dm, hook_name) |
| 117 | + |
| 118 | + def test_override_both_model_and_datamodule(self, hook_name): |
| 119 | + model, dm, trainer = self.reset_instances() |
| 120 | + trainer._data_connector.attach_datamodule(model, datamodule=dm) |
| 121 | + setattr(model, hook_name, self.overridden_func) |
| 122 | + setattr(dm, hook_name, self.overridden_func) |
| 123 | + with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"): |
| 124 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 125 | + |
| 126 | + warning_cache.clear() |
| 127 | + assert hook == getattr(dm, hook_name) |
| 128 | + |
| 129 | + def test_with_datamodule_override_model(self, hook_name): |
| 130 | + model, dm, trainer = self.reset_instances() |
| 131 | + trainer._data_connector.attach_datamodule(model, datamodule=dm) |
| 132 | + setattr(model, hook_name, self.overridden_func) |
| 133 | + with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"): |
| 134 | + hook = trainer._data_connector._datahook_selector.get_hook(hook_name) |
| 135 | + |
| 136 | + warning_cache.clear() |
| 137 | + assert hook == getattr(model, hook_name) |
| 138 | + |
| 139 | + |
| 140 | +def test_invalid_hook_passed_in_datahook_selector(): |
| 141 | + dh_selector = _DataHookSelector(BoringModel(), None) |
| 142 | + with pytest.raises(ValueError, match="is not a shared hook"): |
| 143 | + dh_selector.get_hook("setup") |
| 144 | + |
| 145 | + |
74 | 146 | def test_eval_distributed_sampler_warning(tmpdir): |
75 | 147 | """Test that a warning is raised when `DistributedSampler` is used with evaluation.""" |
76 | 148 |
|
|
0 commit comments