3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Trainer method `predict(...)` for high performance predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579))


- Added `on_before_batch_transfer` and `on_after_batch_transfer` data hooks ([#3671](https://github.com/PyTorchLightning/pytorch-lightning/pull/3671))


- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))


12 changes: 12 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1209,3 +1209,15 @@ transfer_batch_to_device

.. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
:noindex:

on_before_batch_transfer
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer
:noindex:

on_after_batch_transfer
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer
:noindex:
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -374,7 +374,7 @@ def package_list_from_file(file):
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import (
_NATIVE_AMP_AVAILABLE,
_APEX_AVAILABLE,
58 changes: 49 additions & 9 deletions docs/source/extensions/datamodules.rst
@@ -164,6 +164,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
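
A quick illustration (not part of this PR's changes, and only a sketch that reuses the MNIST splits shown elsewhere on this page) of how such stage-dependent logic is commonly written:

.. code-block:: python

    import os

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST


    class MNISTDataModule(LightningDataModule):
        def setup(self, stage=None):
            # `stage` is 'fit', 'test', or None (None means: set up everything)
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(os.getcwd(), train=True, transform=transforms.ToTensor())
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(os.getcwd(), train=False, transform=transforms.ToTensor())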


@@ -179,6 +180,7 @@ To define a DataModule define 5 methods:
- val_dataloader(s)
- test_dataloader(s)


prepare_data
^^^^^^^^^^^^
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed
@@ -196,7 +198,9 @@ settings.
MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).

.. warning:: ``prepare_data`` is called from a single GPU. Do not use it to assign state (``self.x = y``).


setup
^^^^^
@@ -269,6 +273,7 @@ Use this method to generate the val dataloader. Usually you just wrap the datas
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=64)


.. _datamodule-test-dataloader-label:

test_dataloader
@@ -284,24 +289,59 @@ Use this method to generate the test dataloader. Usually you just wrap the datas
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=64)


transfer_batch_to_device
^^^^^^^^^^^^^^^^^^^^^^^^
Override to define how you want to move an arbitrary batch to a device
Override to define how you want to move an arbitrary batch to a device.

.. code-block:: python
.. testcode::

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
class MNISTDataModule(LightningDataModule):
def transfer_batch_to_device(self, batch, device):
x = batch['x']
x = CustomDataWrapper(x)
batch['x'].to(device)
batch['x'] = x.to(device)
return batch


.. note:: To decouple your data from transforms you can parametrize them via `__init__`.
.. note:: This hook only runs on single GPU training and DDP (no data-parallel).
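
``CustomDataWrapper`` in the ``transfer_batch_to_device`` example above is not defined anywhere in these docs; it stands in for any user-defined batch container. A purely hypothetical minimal version, only to make the example self-contained (not part of Lightning or this PR):

.. code-block:: python

    import torch


    class CustomDataWrapper:
        """Hypothetical wrapper around a tensor batch."""

        def __init__(self, data: torch.Tensor):
            self.data = data

        def to(self, device):
            # mirror Tensor.to so the wrapper can be moved like a tensor
            self.data = self.data.to(device)
            return self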


on_before_batch_transfer
^^^^^^^^^^^^^^^^^^^^^^^^
Override to alter or apply augmentations to your batch before it is transferred to the device.

.. testcode::

class MNISTDataModule(LightningDataModule):
def on_before_batch_transfer(self, batch):
batch['x'] = transforms(batch['x'])
return batch


.. note:: This hook only runs on single GPU training and DDP (no data-parallel).


on_after_batch_transfer
^^^^^^^^^^^^^^^^^^^^^^^
Override to alter or apply augmentations to your batch after it is transferred to the device.

.. testcode::

class MNISTDataModule(LightningDataModule):
def on_after_batch_transfer(self, batch):
batch['x'] = gpu_transforms(batch['x'])
return batch


.. note::
    This hook only runs on single GPU training and DDP (no data-parallel). It is also
    called when training on the CPU, in which case augmentations added here or in
    ``on_before_batch_transfer`` have the same effect.



.. note:: To decouple your data from transforms you can parametrize them via ``__init__``.

.. code-block:: python

10 changes: 5 additions & 5 deletions pytorch_lightning/accelerators/accelerator.py
@@ -105,7 +105,7 @@ def teardown(self):
"""
pass

def batch_to_device(self, batch: Any, device: torch.device) -> Any:
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""Moves the batch to the correct device.
The returned batch is of the same type as the input batch, just having all tensors on the correct device.

@@ -114,8 +114,10 @@ def batch_to_device(self, batch: Any, device: torch.device) -> Any:
device: The target device
"""
model = self.lightning_module

if model is not None:
return model.transfer_batch_to_device(batch, device)
Review comment (Contributor):
this doesn't look backward compatible.

cc @Borda @tchaton
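
For context, `model._apply_batch_transfer_handler` (added to `LightningModule` in the `lightning.py` changes below) still calls the user's `transfer_batch_to_device` override between the two new hooks. A rough sketch of the resulting call order, paraphrasing that diff:

    def _apply_batch_transfer_handler(self, batch, device=None):
        batch = self.on_before_batch_transfer(batch)           # new hook, returns batch unchanged by default
        batch = self.transfer_batch_to_device(batch, device)   # existing user override still runs here
        batch = self.on_after_batch_transfer(batch)            # new hook, returns batch unchanged by default
        return batch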

return model._apply_batch_transfer_handler(batch, device)

return move_data_to_device(batch, device)

def on_train_start(self):
@@ -135,9 +137,7 @@ def training_step(self, args):
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

"""
batch = self.to_device(args[0])

args[0] = batch
args[0] = self.to_device(args[0])

with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*args)
23 changes: 1 addition & 22 deletions pytorch_lightning/core/datamodule.py
@@ -19,7 +19,6 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
@@ -95,7 +94,7 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn


class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
@@ -248,26 +247,6 @@ def prepare_data(self, *args, **kwargs):
def setup(self, stage: Optional[str] = None):
pass

@abstractmethod
def train_dataloader(self, *args, **kwargs) -> DataLoader:
pass

@abstractmethod
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
pass

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes."""
3 changes: 1 addition & 2 deletions pytorch_lightning/core/decorators.py
@@ -57,8 +57,7 @@ def auto_transfer_args(self, *args, **kwargs):
if not isinstance(self, LightningModule):
return fn(self, *args, **kwargs)

args = self.transfer_batch_to_device(args, self.device)
kwargs = self.transfer_batch_to_device(kwargs, self.device)
args, kwargs = self.transfer_batch_to_device((args, kwargs))
return fn(self, *args, **kwargs)

return auto_transfer_args
82 changes: 66 additions & 16 deletions pytorch_lightning/core/hooks.py
@@ -320,7 +320,7 @@ def on_after_backward(self):


class DataHooks:
"""Hooks to be used with LightningDataModule."""
"""Hooks to be used for data related stuff."""

def prepare_data(self) -> None:
"""
@@ -564,9 +564,27 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.

Returns:
A reference to the data on the new device.

Example::

def transfer_batch_to_device(self, batch, device)
def transfer_batch_to_device(self, batch, device):
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
@@ -575,29 +593,62 @@ def transfer_batch_to_device(self, batch, device)
                batch = super().transfer_batch_to_device(batch, device)
return batch

See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
"""
device = device or self.device
return move_data_to_device(batch, device)

def on_before_batch_transfer(self, batch):
"""
Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note:
This hook only runs on single GPU training and DDP (no data-parallel).

Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
batch: A batch of data that needs to be altered or augmented.

Returns:
A reference to the data on the new device.
A batch of data

Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).
Example::

def on_before_batch_transfer(self, batch):
batch['x'] = transforms(batch['x'])
return batch

See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
"""
return batch

def on_after_batch_transfer(self, batch):
"""
Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

Args:
batch: A batch of data that needs to be altered or augmented.

Returns:
A batch of data

Example::

def on_after_batch_transfer(self, batch):
batch['x'] = gpu_transforms(batch['x'])
return batch

See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
"""
device = device or self.device
return move_data_to_device(batch, device)
return batch


class CheckpointHooks:
@@ -611,7 +662,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
Args:
checkpoint: Loaded checkpoint


Example::

def on_load_checkpoint(self, checkpoint):
17 changes: 10 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -179,6 +179,12 @@ def logger(self):
""" Reference to the logger object in the Trainer. """
return self.trainer.logger if self.trainer else None

def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None):
batch = self.on_before_batch_transfer(batch)
batch = self.transfer_batch_to_device(batch, device)
batch = self.on_after_batch_transfer(batch)
return batch

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
@@ -1697,7 +1703,7 @@ def to_onnx(
)
input_sample = self.example_input_array

input_sample = self.transfer_batch_to_device(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)

if "example_outputs" not in kwargs:
self.eval()
@@ -1768,18 +1774,15 @@ def to_torchscript(
if self.example_input_array is None:
raise ValueError(
'Choosing method=`trace` requires either `example_inputs`'
' or `model.example_input_array` to be defined'
' or `model.example_input_array` to be defined.'
)
example_inputs = self.example_input_array

# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs)
example_inputs = self._apply_batch_transfer_handler(example_inputs)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(
"The 'method' parameter only supports 'script' or 'trace',"
f" but value given was: {method}"
)
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

self.train(mode)
