
Commit 5d2d9b0

Avoid patching common DataHooks to the LightningModule (#10603)
1 parent 29d5afb commit 5d2d9b0

File tree: 6 files changed, +147 -21 lines

CHANGELOG.md
pytorch_lightning/core/lightning.py
pytorch_lightning/trainer/configuration_validator.py
pytorch_lightning/trainer/connectors/data_connector.py
tests/core/test_datamodules.py
tests/trainer/connectors/test_data_connector.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -514,6 +514,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `DeepSpeedPlugin.{precision,amp_type,amp_level}` properties ([#10657](https://github.com/PyTorchLightning/pytorch-lightning/pull/10657))


+- Removed patching of `on_before_batch_transfer`, `transfer_batch_to_device` and `on_after_batch_transfer` hooks in `LightningModule` ([#10603](https://github.com/PyTorchLightning/pytorch-lightning/pull/10603))
+
+
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))


pytorch_lightning/core/lightning.py

Lines changed: 11 additions & 3 deletions

@@ -38,6 +38,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors

@@ -259,9 +260,16 @@ def _apply_batch_transfer_handler(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
-        batch = self.on_before_batch_transfer(batch, dataloader_idx)
-        batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
-        batch = self.on_after_batch_transfer(batch, dataloader_idx)
+        datahook_selector = (
+            _DataHookSelector(self, None) if self.trainer is None else self.trainer._data_connector._datahook_selector
+        )
+
+        hook = datahook_selector.get_hook("on_before_batch_transfer")
+        batch = hook(batch, dataloader_idx)
+        hook = datahook_selector.get_hook("transfer_batch_to_device")
+        batch = hook(batch, device, dataloader_idx)
+        hook = datahook_selector.get_hook("on_after_batch_transfer")
+        batch = hook(batch, dataloader_idx)
         return batch

     def print(self, *args, **kwargs) -> None:
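
For illustration, here is a minimal sketch (not part of the commit) of the new resolution order. It assumes the `BoringModel` helper from `tests.helpers` used elsewhere in this diff; `MyDataModule` is a hypothetical datamodule that overrides one of the shared hooks:

import torch
from pytorch_lightning import LightningDataModule, Trainer
from tests.helpers import BoringModel  # assumed test helper, as used in the tests below


class MyDataModule(LightningDataModule):  # hypothetical user datamodule
    def on_after_batch_transfer(self, batch, dataloader_idx):
        # runs after the batch has been moved to the target device
        return batch * 2


model, dm, trainer = BoringModel(), MyDataModule(), Trainer()
model.trainer = trainer
trainer._data_connector.attach_datamodule(model, datamodule=dm)

# The datamodule's hook is resolved through the selector at call time;
# the LightningModule instance is no longer monkey-patched with it.
out = model._apply_batch_transfer_handler(torch.ones(2))
assert torch.equal(out, torch.ones(2) * 2)

Before this change, `attach_datamodule` would have copied `dm.on_after_batch_transfer` onto the model via `setattr`, as shown in the removed lines of data_connector.py below.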

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 4 additions & 1 deletion

@@ -206,8 +206,11 @@ def __verify_dp_batch_transfer_support(trainer: "pl.Trainer", model: "pl.Lightni
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
     batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
+    datahook_selector = trainer._data_connector._datahook_selector
     for hook in batch_transfer_hooks:
-        if trainer._accelerator_connector.use_dp and is_overridden(hook, model):
+        if trainer._accelerator_connector.use_dp and (
+            is_overridden(hook, datahook_selector.model) or is_overridden(hook, datahook_selector.datamodule)
+        ):
             raise MisconfigurationException(f"Overriding `{hook}` is not supported in DP mode.")

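As a hedged illustration (hypothetical usage, not from the commit; it assumes a machine with two GPUs, the `BoringModel`/`BoringDataModule` test helpers, and a made-up `DeviceAwareDM`), the validator now rejects a batch-transfer override coming from either the model or the datamodule when DP is selected:

from pytorch_lightning import Trainer
from tests.helpers import BoringDataModule, BoringModel  # assumed test helpers


class DeviceAwareDM(BoringDataModule):  # hypothetical datamodule with a custom transfer hook
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return super().transfer_batch_to_device(batch, device, dataloader_idx)


trainer = Trainer(strategy="dp", accelerator="gpu", devices=2)
# Previously only hooks overridden on the LightningModule were caught here; now this is
# expected to raise: MisconfigurationException: Overriding `transfer_batch_to_device`
# is not supported in DP mode.
trainer.fit(BoringModel(), datamodule=DeviceAwareDM())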
pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 54 additions & 9 deletions

@@ -13,8 +13,8 @@
 # limitations under the License.
 import multiprocessing
 import os
-from dataclasses import dataclass
-from typing import Any, Collection, List, Optional, Tuple, Union
+from dataclasses import dataclass, field
+from typing import Any, Callable, Collection, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler

@@ -40,7 +40,9 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
-from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache
+
+warning_cache = WarningCache()


 class DataConnector:

@@ -52,6 +54,8 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")

+        self._datahook_selector = _DataHookSelector(None, None)
+
     @property
     def _should_reload_train_dl(self) -> bool:
         """Check if train dataloader should be reloaded."""

@@ -192,6 +196,8 @@ def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
     ) -> None:
         # If we have a datamodule, attach necessary hooks + dataloaders
+        self._datahook_selector = _DataHookSelector(model, datamodule)
+
         if datamodule is None:
             return

@@ -200,12 +206,6 @@
         self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
         self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")

-        # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
-        batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
-        for hook in batch_transfer_hooks:
-            if is_overridden(hook, datamodule):
-                setattr(model, hook, getattr(datamodule, hook))
-
         self.trainer.datamodule = datamodule
         datamodule.trainer = self.trainer

@@ -555,3 +555,48 @@ def is_module(self) -> bool:
         from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import

         return isinstance(self.instance, (LightningModule, LightningDataModule))
+
+
+@dataclass
+class _DataHookSelector:
+    """Stores the info about the shared DataHooks within LightningModule and LightningDataModule.
+
+    The hook source can be
+
+    1. a method from the :class:`~pytorch_lightning.core.lightning.LightningModule`,
+    2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+
+    Arguments:
+        model: A LightningModule
+        datamodule: A LightningDataModule
+    """
+
+    model: "pl.LightningModule"
+    datamodule: Optional["pl.LightningDataModule"]
+    _valid_hooks: Tuple[str] = field(
+        default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
+    )
+
+    def get_hook(self, hook_name: str) -> Callable:
+        if hook_name not in self._valid_hooks:
+            raise ValueError(
+                f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`."
+                f" Valid hooks are {self._valid_hooks}."
+            )
+
+        if self.datamodule is None:
+            return getattr(self.model, hook_name)
+
+        if is_overridden(hook_name, self.datamodule):
+            if is_overridden(hook_name, self.model):
+                warning_cache.warn(
+                    f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`."
+                    " It will use the implementation from `LightningDataModule` instance."
+                )
+            return getattr(self.datamodule, hook_name)
+
+        warning_cache.warn(
+            f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
+            " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
+        )
+        return getattr(self.model, hook_name)
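
A small standalone usage sketch of the selector added above (assumptions: run outside a `Trainer`, with the `BoringModel` test helper); it mirrors `test_invalid_hook_passed_in_datahook_selector` further down:

from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
from tests.helpers import BoringModel  # assumed test helper

model = BoringModel()
selector = _DataHookSelector(model, None)

# With no datamodule attached, the LightningModule's own hook is returned.
assert selector.get_hook("transfer_batch_to_device") == model.transfer_batch_to_device

# Only the three shared batch-transfer hooks are valid; anything else raises a ValueError.
selector.get_hook("setup")  # ValueError: `setup` is not a shared hook ...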

tests/core/test_datamodules.py

Lines changed: 2 additions & 7 deletions

@@ -26,7 +26,6 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf

@@ -309,15 +308,11 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
     batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))

     trainer = Trainer(accelerator="gpu", devices=1)
+    model.trainer = trainer
     # running .fit() would require us to implement custom data loaders, we mock the model reference instead
     get_module_mock.return_value = model
-    if is_overridden("transfer_batch_to_device", dm):
-        model.transfer_batch_to_device = dm.transfer_batch_to_device
-
-    model.on_before_batch_transfer = dm.on_before_batch_transfer
-    model.transfer_batch_to_device = dm.transfer_batch_to_device
-    model.on_after_batch_transfer = dm.on_after_batch_transfer

+    trainer._data_connector.attach_datamodule(model, datamodule=dm)
     batch_gpu = trainer.strategy.batch_to_device(batch, expected_device)

     assert dm.on_before_batch_transfer_hook_rank == 0

tests/trainer/connectors/test_data_connector.py

Lines changed: 73 additions & 1 deletion

@@ -17,9 +17,10 @@
 from torch.utils.data import DataLoader

 from pytorch_lightning import Trainer
-from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
+from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from tests.deprecated_api import no_warning_call
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.boring_model import RandomDataset

@@ -71,6 +72,77 @@ def test_dataloader_source_request_from_module():
     module.foo.assert_called_once()


+@pytest.mark.parametrize(
+    "hook_name", ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
+)
+class TestDataHookSelector:
+    def overridden_func(self, batch, *args, **kwargs):
+        return batch
+
+    def reset_instances(self):
+        return BoringDataModule(), BoringModel(), Trainer()
+
+    def test_no_datamodule_no_overridden(self, hook_name):
+        model, _, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=None)
+        with no_warning_call(match="have overridden `{hook_name}` in both"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        assert hook == getattr(model, hook_name)
+
+    def test_with_datamodule_no_overridden(self, hook_name):
+        model, dm, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=dm)
+        with no_warning_call(match="have overridden `{hook_name}` in both"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        assert hook == getattr(model, hook_name)
+
+    def test_override_model_hook(self, hook_name):
+        model, dm, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=dm)
+        with no_warning_call(match="have overridden `{hook_name}` in both"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        assert hook == getattr(model, hook_name)
+
+    def test_override_datamodule_hook(self, hook_name):
+        model, dm, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=dm)
+        setattr(dm, hook_name, self.overridden_func)
+        with no_warning_call(match="have overridden `{hook_name}` in both"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        assert hook == getattr(dm, hook_name)
+
+    def test_override_both_model_and_datamodule(self, hook_name):
+        model, dm, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=dm)
+        setattr(model, hook_name, self.overridden_func)
+        setattr(dm, hook_name, self.overridden_func)
+        with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        warning_cache.clear()
+        assert hook == getattr(dm, hook_name)
+
+    def test_with_datamodule_override_model(self, hook_name):
+        model, dm, trainer = self.reset_instances()
+        trainer._data_connector.attach_datamodule(model, datamodule=dm)
+        setattr(model, hook_name, self.overridden_func)
+        with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
+            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+
+        warning_cache.clear()
+        assert hook == getattr(model, hook_name)
+
+
+def test_invalid_hook_passed_in_datahook_selector():
+    dh_selector = _DataHookSelector(BoringModel(), None)
+    with pytest.raises(ValueError, match="is not a shared hook"):
+        dh_selector.get_hook("setup")
+
+
 def test_eval_distributed_sampler_warning(tmpdir):
     """Test that a warning is raised when `DistributedSampler` is used with evaluation."""
