@@ -336,7 +336,7 @@ def on_post_move_to_device(self):
 
 
 class DataHooks:
-    """Hooks to be used with LightningDataModule."""
+    """Hooks to be used for data related stuff."""
 
     def prepare_data(self) -> None:
         """
@@ -580,9 +580,27 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
 
         For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
 
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+
+        Note:
+            This hook only runs on single GPU training and DDP (no data-parallel).
+            If you need multi-GPU support for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
         Example::
 
-            def transfer_batch_to_device(self, batch, device)
+            def transfer_batch_to_device(self, batch, device):
                 if isinstance(batch, CustomBatch):
                     # move all tensors in your custom data structure to the device
                     batch.samples = batch.samples.to(device)
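
The hunk above pins down the contract of ``transfer_batch_to_device``: the hook should only move data to the device it is handed, and the default implementation delegates to ``move_data_to_device`` (referenced in the ``See Also`` entries in the next hunk), which walks common collections and calls ``.to(device)`` on every tensor it finds. A minimal sketch of that default behaviour on a toy dict batch, not part of the patch (the batch layout is an illustrative assumption)::

    import torch
    from pytorch_lightning.utilities.apply_func import move_data_to_device

    # hypothetical nested batch: tensors inside a dict and a list
    batch = {"x": torch.randn(4, 3), "y": [torch.zeros(4), torch.ones(4)]}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # tensors nested in dicts/lists are moved; non-tensor values pass through unchanged
    moved = move_data_to_device(batch, device)
    print(moved["x"].device, moved["y"][0].device)
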
@@ -591,29 +609,62 @@ def transfer_batch_to_device(self, batch, device)
                     batch = super().transfer_batch_to_device(data, device)
                 return batch
 
+        See Also:
+            - :meth:`move_data_to_device`
+            - :meth:`apply_to_collection`
+        """
+        device = device or self.device
+        return move_data_to_device(batch, device)
+
+    def on_before_batch_transfer(self, batch):
+        """
+        Override to alter or apply batch augmentations to your batch before it is transferred to the device.
+
+        Note:
+            This hook only runs on single GPU training and DDP (no data-parallel).
+
         Args:
-            batch: A batch of data that needs to be transferred to a new device.
-            device: The target device as defined in PyTorch.
+            batch: A batch of data that needs to be altered or augmented.
 
         Returns:
-            A reference to the data on the new device.
+            A batch of data
 
-        Note:
-            This hook should only transfer the data and not modify it, nor should it move the data to
-            any other device than the one passed in as argument (unless you know what you are doing).
+        Example::
+
+            def on_before_batch_transfer(self, batch):
+                batch['x'] = transforms(batch['x'])
+                return batch
+
+        See Also:
+            - :meth:`on_after_batch_transfer`
+            - :meth:`transfer_batch_to_device`
+        """
+        return batch
+
+    def on_after_batch_transfer(self, batch):
+        """
+        Override to alter or apply batch augmentations to your batch after it is transferred to the device.
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
-            If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
-            you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
-            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented.
+
+        Returns:
+            A batch of data
+
+        Example::
+
+            def on_after_batch_transfer(self, batch):
+                batch['x'] = gpu_transforms(batch['x'])
+                return batch
 
         See Also:
-            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
-            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+            - :meth:`on_before_batch_transfer`
+            - :meth:`transfer_batch_to_device`
         """
-        device = device or self.device
-        return move_data_to_device(batch, device)
+        return batch
 
 
 class CheckpointHooks:
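
Taken together, the two new hooks bracket ``transfer_batch_to_device``: ``on_before_batch_transfer`` runs before the batch is moved (typically while it is still on the CPU), the batch is then transferred to the target device, and ``on_after_batch_transfer`` runs on the device-resident batch. A minimal usage sketch, not part of the patch (the module name, dict-shaped batch, and the particular transforms are illustrative assumptions; other LightningModule methods are omitted)::

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def on_before_batch_transfer(self, batch):
            # runs before the transfer: cheap, per-batch augmentation (hypothetical horizontal flip)
            batch['x'] = torch.flip(batch['x'], dims=[-1])
            return batch

        def on_after_batch_transfer(self, batch):
            # runs after the transfer, so tensors are already on the target device:
            # a good place for batched, GPU-friendly transforms such as normalization
            batch['x'] = (batch['x'] - 0.5) / 0.5
            return batch
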
@@ -627,7 +678,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         Args:
             checkpoint: Loaded checkpoint
 
-
         Example::
 
             def on_load_checkpoint(self, checkpoint):