Commit bcc0004

Add before_batch_transfer and after_batch_transfer hooks (#3671)
* add hooks
* comment
* docs
* add tests
* make it private
* fix tests
* docs
* chlog
* testcode
* codefactor
* fix doctest
* fix doctest
* suggestions
* is always overriden
* pep and BoringModel
* BoringModel
* docs
* docs
* docs
* fix
* rebase
* rebase
* suggestions
* docs
* suggestions
* try fix docs
* docs
* update name
* yapf
* docs
* rebase
* yapf
1 parent b019c25 commit bcc0004

File tree

15 files changed: +238 -106 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579))
 
 
+- Added `on_before_batch_transfer` and `on_after_batch_transfer` data hooks ([#3671](https://github.com/PyTorchLightning/pytorch-lightning/pull/3671))
+
+
 - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))
 
 

docs/source/common/lightning_module.rst

Lines changed: 12 additions & 0 deletions
@@ -1209,3 +1209,15 @@ transfer_batch_to_device
 
 .. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
     :noindex:
+
+on_before_batch_transfer
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer
+    :noindex:
+
+on_after_batch_transfer
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer
+    :noindex:

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def package_list_from_file(file):
 import torch
 from torch import nn
 import pytorch_lightning as pl
-from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.utilities import (
     _NATIVE_AMP_AVAILABLE,
     _APEX_AVAILABLE,

docs/source/extensions/datamodules.rst

Lines changed: 49 additions & 9 deletions
@@ -164,6 +164,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
         def test_dataloader(self):
             return DataLoader(self.mnist_test, batch_size=32)
 
+
 .. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
 
 
@@ -179,6 +180,7 @@ To define a DataModule define 5 methods:
 - val_dataloader(s)
 - test_dataloader(s)
 
+
 prepare_data
 ^^^^^^^^^^^^
 Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed
@@ -196,7 +198,9 @@ settings.
             MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
             MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
 
-.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
+
+.. warning:: ``prepare_data`` is called from a single GPU. Do not use it to assign state (``self.x = y``).
+
 
 setup
 ^^^^^
@@ -269,6 +273,7 @@ Use this method to generate the val dataloader. Usually you just wrap the datas
     def val_dataloader(self):
         return DataLoader(self.mnist_val, batch_size=64)
 
+
 .. _datamodule-test-dataloader-label:
 
 test_dataloader
@@ -284,24 +289,59 @@ Use this method to generate the test dataloader. Usually you just wrap the datas
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=64)
 
+
 transfer_batch_to_device
 ^^^^^^^^^^^^^^^^^^^^^^^^
-Override to define how you want to move an arbitrary batch to a device
+Override to define how you want to move an arbitrary batch to a device.
 
-.. code-block:: python
+.. testcode::
 
-    import pytorch_lightning as pl
-
-
-    class MNISTDataModule(pl.LightningDataModule):
+    class MNISTDataModule(LightningDataModule):
         def transfer_batch_to_device(self, batch, device):
             x = batch['x']
             x = CustomDataWrapper(x)
-            batch['x'].to(device)
+            batch['x'] = x.to(device)
             return batch
 
 
-.. note:: To decouple your data from transforms you can parametrize them via `__init__`.
+.. note:: This hook only runs on single GPU training and DDP (no data-parallel).
+
+
+on_before_batch_transfer
+^^^^^^^^^^^^^^^^^^^^^^^^
+Override to alter or apply augmentations to your batch before it is transferred to the device.
+
+.. testcode::
+
+    class MNISTDataModule(LightningDataModule):
+        def on_before_batch_transfer(self, batch):
+            batch['x'] = transforms(batch['x'])
+            return batch
+
+
+.. note:: This hook only runs on single GPU training and DDP (no data-parallel).
+
+
+on_after_batch_transfer
+^^^^^^^^^^^^^^^^^^^^^^^
+Override to alter or apply augmentations to your batch after it is transferred to the device.
+
+.. testcode::
+
+    class MNISTDataModule(LightningDataModule):
+        def on_after_batch_transfer(self, batch):
+            batch['x'] = gpu_transforms(batch['x'])
+            return batch
+
+
+.. note::
+    This hook only runs on single GPU training and DDP (no data-parallel). This hook
+    will also be called when using CPU device, so adding augmentations here or in
+    ``on_before_batch_transfer`` means the same thing.
+
+
+
+.. note:: To decouple your data from transforms you can parametrize them via ``__init__``.
 
 .. code-block:: python
 
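The snippets above reference ``transforms`` and ``gpu_transforms`` without defining them. As a rough, self-contained sketch along the same lines (not part of this commit), the example below fills those gaps with placeholder callables; the random dataset, the dict batch layout with key 'x', and the toy transforms are illustrative assumptions only.

# Self-contained sketch (assumes pytorch_lightning and torch are installed).
# The dataset, the dict batch layout and both transforms are placeholders, not library code.
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule


class RandomDictDataset(Dataset):
    def __init__(self, size=64, dim=32):
        self.data = torch.randn(size, dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {'x': self.data[idx]}


def transforms(x):
    # stand-in for a CPU-side augmentation
    return x + 0.01 * torch.randn_like(x)


def gpu_transforms(x):
    # stand-in for an augmentation meant to run on the already-transferred batch
    return (x - x.mean()) / (x.std() + 1e-6)


class DictDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(RandomDictDataset(), batch_size=8)

    def on_before_batch_transfer(self, batch):
        batch['x'] = transforms(batch['x'])
        return batch

    def on_after_batch_transfer(self, batch):
        batch['x'] = gpu_transforms(batch['x'])
        return batch
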
pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 5 deletions
@@ -125,7 +125,7 @@ def teardown(self):
         """
         pass
 
-    def batch_to_device(self, batch: Any, device: torch.device) -> Any:
+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """Moves the batch to the correct device.
         The returned batch is of the same type as the input batch, just having all tensors on the correct device.
 
@@ -134,8 +134,10 @@ def batch_to_device(self, batch: Any, device: torch.device) -> Any:
             device: The target device
         """
         model = self.lightning_module
+
         if model is not None:
-            return model.transfer_batch_to_device(batch, device)
+            return model._apply_batch_transfer_handler(batch, device)
+
         return move_data_to_device(batch, device)
 
     def on_train_start(self):
@@ -155,9 +157,7 @@ def training_step(self, args):
             :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
 
         """
-        batch = self.to_device(args[0])
-
-        args[0] = batch
+        args[0] = self.to_device(args[0])
 
         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
             return self.training_type_plugin.training_step(*args)
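As a rough, standalone illustration of the dispatch that ``batch_to_device`` now performs (plain Python, not the real ``Accelerator``; the simplified ``move_data_to_device`` below is a stand-in for the utility of the same name):

# Sketch only: with a module attached, the batch goes through its transfer handler;
# without one, it falls back to a plain device move, as before this commit.
from typing import Any, Optional

import torch


def move_data_to_device(batch: Any, device: Optional[torch.device]) -> Any:
    # simplified stand-in for pytorch_lightning.utilities.apply_func.move_data_to_device
    return batch.to(device) if isinstance(batch, torch.Tensor) and device is not None else batch


def batch_to_device(model: Any, batch: Any, device: Optional[torch.device] = None) -> Any:
    if model is not None:
        # a LightningModule chains on_before_batch_transfer -> transfer_batch_to_device -> on_after_batch_transfer
        return model._apply_batch_transfer_handler(batch, device)
    return move_data_to_device(batch, device)


print(batch_to_device(None, torch.zeros(2), torch.device("cpu")))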

pytorch_lightning/core/datamodule.py

Lines changed: 1 addition & 22 deletions
@@ -19,7 +19,6 @@
 from argparse import ArgumentParser, Namespace
 from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
 
-import torch
 from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
@@ -95,7 +94,7 @@ def wrapped_fn(*args, **kwargs):
     return wrapped_fn
 
 
-class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
@@ -248,26 +247,6 @@ def prepare_data(self, *args, **kwargs):
     def setup(self, stage: Optional[str] = None):
         pass
 
-    @abstractmethod
-    def train_dataloader(self, *args, **kwargs) -> DataLoader:
-        pass
-
-    @abstractmethod
-    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
-        pass
-
-    @abstractmethod
-    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
-        pass
-
-    @abstractmethod
-    def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
-        pass
-
-    @abstractmethod
-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        pass
-
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         r"""Extends existing argparse by default `LightningDataModule` attributes."""

pytorch_lightning/core/decorators.py

Lines changed: 1 addition & 2 deletions
@@ -58,8 +58,7 @@ def auto_transfer_args(self, *args, **kwargs):
         if not isinstance(self, LightningModule):
             return fn(self, *args, **kwargs)
 
-        args = self.transfer_batch_to_device(args, self.device)
-        kwargs = self.transfer_batch_to_device(kwargs, self.device)
+        args, kwargs = self.transfer_batch_to_device((args, kwargs))
         return fn(self, *args, **kwargs)
 
     return auto_transfer_args
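The updated decorator makes a single call on the ``(args, kwargs)`` tuple instead of two separate calls; this works because the transfer ultimately goes through ``move_data_to_device``, which recurses into tuples and dicts and leaves non-tensor values untouched. A small, illustrative check of that behaviour (the tensors and device choice are arbitrary):

# Illustration only: one move_data_to_device call handles a nested (args, kwargs) pair.
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

args = (torch.zeros(2), "not-a-tensor")              # non-tensor entries are passed through untouched
kwargs = {"mask": torch.ones(2, dtype=torch.bool)}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args, kwargs = move_data_to_device((args, kwargs), device)
print(args[0].device, kwargs["mask"].device)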

pytorch_lightning/core/hooks.py

Lines changed: 66 additions & 16 deletions
@@ -336,7 +336,7 @@ def on_post_move_to_device(self):
 
 
 class DataHooks:
-    """Hooks to be used with LightningDataModule."""
+    """Hooks to be used for data related stuff."""
 
     def prepare_data(self) -> None:
         """
@@ -580,9 +580,27 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
 
         For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
 
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+
+        Note:
+            This hook only runs on single GPU training and DDP (no data-parallel).
+            If you need multi-GPU support for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
         Example::
 
-            def transfer_batch_to_device(self, batch, device)
+            def transfer_batch_to_device(self, batch, device):
                 if isinstance(batch, CustomBatch):
                     # move all tensors in your custom data structure to the device
                     batch.samples = batch.samples.to(device)
@@ -591,29 +609,62 @@ def transfer_batch_to_device(self, batch, device)
                 batch = super().transfer_batch_to_device(data, device)
                 return batch
 
+        See Also:
+            - :meth:`move_data_to_device`
+            - :meth:`apply_to_collection`
+        """
+        device = device or self.device
+        return move_data_to_device(batch, device)
+
+    def on_before_batch_transfer(self, batch):
+        """
+        Override to alter or apply batch augmentations to your batch before it is transferred to the device.
+
+        Note:
+            This hook only runs on single GPU training and DDP (no data-parallel).
+
         Args:
-            batch: A batch of data that needs to be transferred to a new device.
-            device: The target device as defined in PyTorch.
+            batch: A batch of data that needs to be altered or augmented.
 
         Returns:
-            A reference to the data on the new device.
+            A batch of data
 
-        Note:
-            This hook should only transfer the data and not modify it, nor should it move the data to
-            any other device than the one passed in as argument (unless you know what you are doing).
+        Example::
+
+            def on_before_batch_transfer(self, batch):
+                batch['x'] = transforms(batch['x'])
+                return batch
+
+        See Also:
+            - :meth:`on_after_batch_transfer`
+            - :meth:`transfer_batch_to_device`
+        """
+        return batch
+
+    def on_after_batch_transfer(self, batch):
+        """
+        Override to alter or apply batch augmentations to your batch after it is transferred to the device.
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
-            If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
-            you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
-            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented.
+
+        Returns:
+            A batch of data
+
+        Example::
+
+            def on_after_batch_transfer(self, batch):
+                batch['x'] = gpu_transforms(batch['x'])
+                return batch
 
         See Also:
-            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
-            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+            - :meth:`on_before_batch_transfer`
+            - :meth:`transfer_batch_to_device`
         """
-        device = device or self.device
-        return move_data_to_device(batch, device)
+        return batch
 
 
 class CheckpointHooks:
@@ -627,7 +678,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         Args:
             checkpoint: Loaded checkpoint
 
-
         Example::
 
             def on_load_checkpoint(self, checkpoint):
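Since the default implementations of both new hooks simply ``return batch``, a user only overrides what they need. Below is a rough sketch of a LightningModule that overrides both and calls the three hooks in the documented order; the dict batch layout and the arithmetic used as "augmentation" are illustrative assumptions:

# Minimal sketch (assumes pytorch_lightning is installed; batches are dicts with an 'x' tensor).
import torch
from pytorch_lightning import LightningModule


class AugmentingModule(LightningModule):
    def on_before_batch_transfer(self, batch):
        # cheap CPU-side jitter before the batch is moved to the device
        batch['x'] = batch['x'] + 0.01 * torch.randn_like(batch['x'])
        return batch

    def on_after_batch_transfer(self, batch):
        # normalization once the batch is on the target device
        batch['x'] = (batch['x'] - 0.5) / 0.5
        return batch


model = AugmentingModule()
batch = model.on_before_batch_transfer({'x': torch.rand(4, 3)})
batch = model.transfer_batch_to_device(batch)   # defaults to model.device (CPU here)
batch = model.on_after_batch_transfer(batch)
print(batch['x'].shape)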

pytorch_lightning/core/lightning.py

Lines changed: 10 additions & 7 deletions
@@ -179,6 +179,12 @@ def logger(self):
         """ Reference to the logger object in the Trainer. """
         return self.trainer.logger if self.trainer else None
 
+    def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None):
+        batch = self.on_before_batch_transfer(batch)
+        batch = self.transfer_batch_to_device(batch, device)
+        batch = self.on_after_batch_transfer(batch)
+        return batch
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
@@ -1697,7 +1703,7 @@ def to_onnx(
             )
             input_sample = self.example_input_array
 
-        input_sample = self.transfer_batch_to_device(input_sample)
+        input_sample = self._apply_batch_transfer_handler(input_sample)
 
         if "example_outputs" not in kwargs:
             self.eval()
@@ -1768,18 +1774,15 @@ def to_torchscript(
             if self.example_input_array is None:
                 raise ValueError(
                     'Choosing method=`trace` requires either `example_inputs`'
-                    ' or `model.example_input_array` to be defined'
+                    ' or `model.example_input_array` to be defined.'
                 )
             example_inputs = self.example_input_array
 
             # automatically send example inputs to the right device and use trace
-            example_inputs = self.transfer_batch_to_device(example_inputs)
+            example_inputs = self._apply_batch_transfer_handler(example_inputs)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
-            raise ValueError(
-                "The 'method' parameter only supports 'script' or 'trace',"
-                f" but value given was: {method}"
-            )
+            raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")
 
         self.train(mode)
 
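``_apply_batch_transfer_handler`` fixes the order the new hooks run in: ``on_before_batch_transfer`` on the untransferred batch, then ``transfer_batch_to_device``, then ``on_after_batch_transfer`` on the device-resident batch. A dependency-free sketch that just records that order (the ``Recorder`` class is purely illustrative, not the real LightningModule):

# Dependency-free sketch of the call order; not the real LightningModule.
calls = []


class Recorder:
    def on_before_batch_transfer(self, batch):
        calls.append("on_before_batch_transfer")
        return batch

    def transfer_batch_to_device(self, batch, device=None):
        calls.append("transfer_batch_to_device")
        return batch

    def on_after_batch_transfer(self, batch):
        calls.append("on_after_batch_transfer")
        return batch

    def _apply_batch_transfer_handler(self, batch, device=None):
        # mirrors the method added above
        batch = self.on_before_batch_transfer(batch)
        batch = self.transfer_batch_to_device(batch, device)
        batch = self.on_after_batch_transfer(batch)
        return batch


Recorder()._apply_batch_transfer_handler({"x": [1.0, 2.0]})
assert calls == ["on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer"]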