
Commit 0c2d88f

add hooks
1 parent db69d16 · commit 0c2d88f


9 files changed, +36 -34 lines changed


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def train_or_test(self):
     def batch_to_device(self, batch: Any, device: torch.device):
         model = self.trainer.get_model()
         if model is not None:
-            return model.transfer_batch_to_device(batch, device)
+            return model.prepare_batch_for_transfer(batch, device)
         return move_data_to_device(batch, device)

     def training_step_end(self, output):
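
With this change, batch_to_device routes the batch through the model's new prepare_batch_for_transfer wrapper instead of calling transfer_batch_to_device directly, so the before/after hooks added in pytorch_lightning/core/hooks.py (below) run around the device move. As a rough sketch of the resulting order of operations, not the Trainer's actual code:

import torch
import pytorch_lightning as pl

def batch_to_device_sketch(model: pl.LightningModule, batch, device: torch.device):
    # Sketch only: the effective call chain once batch_to_device delegates to
    # prepare_batch_for_transfer. Hook names match the ones added in this commit.
    batch = model.on_before_batch_transfer(batch)           # runs before the move
    batch = model.transfer_batch_to_device(batch, device)   # the device transfer itself
    batch = model.on_after_batch_transfer(batch)            # runs after the move
    return batch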

pytorch_lightning/core/datamodule.py

Lines changed: 5 additions & 17 deletions
@@ -94,7 +94,11 @@ def wrapped_fn(*args, **kwargs):
     return wrapped_fn


-class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(
+    CheckpointHooks,
+    DataHooks,
+    metaclass=_DataModuleWrapper
+):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
@@ -247,22 +251,6 @@ def prepare_data(self, *args, **kwargs):
     def setup(self, stage: Optional[str] = None):
         pass

-    @abstractmethod
-    def train_dataloader(self, *args, **kwargs) -> DataLoader:
-        pass
-
-    @abstractmethod
-    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
-        pass
-
-    @abstractmethod
-    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
-        pass
-
-    @abstractmethod
-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        pass
-
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         r"""Extends existing argparse by default `LightningDataModule` attributes.

pytorch_lightning/core/decorators.py

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ def auto_transfer_args(self, *args, **kwargs):
         if not isinstance(self, LightningModule):
             return fn(self, *args, **kwargs)

-        args = self.transfer_batch_to_device(args, self.device)
-        kwargs = self.transfer_batch_to_device(kwargs, self.device)
+        args = self.transfer_batch_to_device(args)
+        kwargs = self.transfer_batch_to_device(kwargs)
         return fn(self, *args, **kwargs)

     return auto_transfer_args
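
Dropping the explicit self.device argument relies on transfer_batch_to_device now defaulting to the module's own device (see the hooks.py change below). A hedged sketch of the auto_move_data decorator that uses this path; the linear layer and input shape are placeholders:

import torch
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    @auto_move_data  # moves args/kwargs via transfer_batch_to_device before forward runs
    def forward(self, x):
        return self.layer(x)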

pytorch_lightning/core/hooks.py

Lines changed: 16 additions & 3 deletions
@@ -14,7 +14,7 @@

 """Various hooks to be used in the Lightning code."""

-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
@@ -300,7 +300,7 @@ def on_after_backward(self):


 class DataHooks:
-    """Hooks to be used with LightningDataModule."""
+    """Hooks to be used for data related stuff."""
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -508,7 +508,7 @@ def val_dataloader(self):
             will have an argument ``dataloader_idx`` which matches the order here.
         """

-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+    def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """
         Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
         wrapped in a custom data structure.
@@ -556,8 +556,21 @@ def transfer_batch_to_device(self, batch, device)
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
             - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
         """
+        device = device or self.device
         return move_data_to_device(batch, device)

+    def on_before_batch_transfer(self, batch):
+        return batch
+
+    def on_after_batch_transfer(self, batch):
+        return batch
+
+    def prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None):
+        batch = self.on_before_batch_transfer(batch)
+        batch = self.transfer_batch_to_device(batch, device)
+        batch = self.on_after_batch_transfer(batch)
+        return batch
+

 class CheckpointHooks:
     """Hooks to be used with Checkpointing."""

pytorch_lightning/core/lightning.py

Lines changed: 4 additions & 4 deletions
@@ -55,12 +55,12 @@

 class LightningModule(
     ABC,
+    CheckpointHooks,
+    DataHooks,
     DeviceDtypeModuleMixin,
     GradInformation,
-    ModelIO,
     ModelHooks,
-    DataHooks,
-    CheckpointHooks,
+    ModelIO,
     Module,
 ):
     # Below is for property support of JIT in PyTorch 1.7
@@ -1787,7 +1787,7 @@ def to_torchscript(
             if example_inputs is None:
                 example_inputs = self.example_input_array
             # automatically send example inputs to the right device and use trace
-            example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
+            example_inputs = self.prepare_batch_for_transfer(example_inputs)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"

pytorch_lightning/core/memory.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def _forward_example_input(self) -> None:
         trainer = self._model.trainer

         input_ = model.example_input_array
-        input_ = model.transfer_batch_to_device(input_, model.device)
+        input_ = model.prepare_batch_for_transfer(input_, model.device)

         if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
             model.forward = torch.cuda.amp.autocast()(model.forward)

pytorch_lightning/loggers/tensorboard.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def log_graph(self, model: LightningModule, input_array=None):
             input_array = model.example_input_array

         if input_array is not None:
-            input_array = model.transfer_batch_to_device(input_array, model.device)
+            input_array = model.prepare_batch_for_transfer(input_array)
             self.experiment.add_graph(model, input_array)
         else:
             rank_zero_warn('Could not log computational graph since the'

pytorch_lightning/loggers/test_tube.py

Lines changed: 2 additions & 5 deletions
@@ -146,11 +146,8 @@ def log_graph(self, model: LightningModule, input_array=None):
             input_array = model.example_input_array

         if input_array is not None:
-            self.experiment.add_graph(
-                model,
-                model.transfer_batch_to_device(
-                    model.example_input_array, model.device)
-            )
+            input_array = self.prepare_batch_for_transfer(input_array)
+            self.experiment.add_graph(model, input_array)
         else:
             rank_zero_warn('Could not log computational graph since the'
                            ' `model.example_input_array` attribute is not set'

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 4 additions & 0 deletions
@@ -118,8 +118,12 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st
             model.test_dataloader = datamodule.test_dataloader

         # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
+        if is_overridden('on_before_batch_transfer', datamodule):
+            model.on_before_batch_transfer = datamodule.on_before_batch_transfer
         if is_overridden('transfer_batch_to_device', datamodule):
             model.transfer_batch_to_device = datamodule.transfer_batch_to_device
+        if is_overridden('on_after_batch_transfer', datamodule):
+            model.on_after_batch_transfer = datamodule.on_after_batch_transfer

         self.trainer.datamodule = datamodule
         datamodule.trainer = self.trainer
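
Because attach_datamodule copies overridden hooks from the datamodule onto the model, the same pre/post-transfer logic can live on a LightningDataModule and still run at batch-transfer time. A hedged sketch; the 255-scaling body is a placeholder, not part of the commit:

import pytorch_lightning as pl

class ScalingDataModule(pl.LightningDataModule):
    # If this override exists, attach_datamodule points
    # model.on_after_batch_transfer at it.
    def on_after_batch_transfer(self, batch):
        x, y = batch
        return x / 255.0, y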
