diff --git a/CHANGELOG.md b/CHANGELOG.md index ef33b0b06890a..1c7d5b128449c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579)) +- Added `on_before_batch_transfer` and `on_after_batch_transfer` data hooks ([#3671](https://github.com/PyTorchLightning/pytorch-lightning/pull/3671)) + + - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 943525902f41b..43e58662c3b5b 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1209,3 +1209,15 @@ transfer_batch_to_device .. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device :noindex: + +on_before_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer + :noindex: + +on_after_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer + :noindex: diff --git a/docs/source/conf.py b/docs/source/conf.py index 61c5c39361e37..813d5ee978821 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -374,7 +374,7 @@ def package_list_from_file(file): import torch from torch import nn import pytorch_lightning as pl -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities import ( _NATIVE_AMP_AVAILABLE, _APEX_AVAILABLE, diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 443cd5be4204b..01d69da76804e 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -164,6 +164,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) + .. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``. @@ -179,6 +180,7 @@ To define a DataModule define 5 methods: - val_dataloader(s) - test_dataloader(s) + prepare_data ^^^^^^^^^^^^ Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed @@ -196,7 +198,9 @@ settings. MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) -.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`). + +.. warning:: ``prepare_data`` is called from a single GPU. Do not use it to assign state (``self.x = y``). + setup ^^^^^ @@ -269,6 +273,7 @@ Use this method to generate the val dataloader. Usually you just wrap the datas def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=64) + .. _datamodule-test-dataloader-label: test_dataloader @@ -284,24 +289,59 @@ Use this method to generate the test dataloader. 
Usually you just wrap the datas def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=64) + transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ -Override to define how you want to move an arbitrary batch to a device +Override to define how you want to move an arbitrary batch to a device. -.. code-block:: python +.. testcode:: - import pytorch_lightning as pl - - - class MNISTDataModule(pl.LightningDataModule): + class MNISTDataModule(LightningDataModule): def transfer_batch_to_device(self, batch, device): x = batch['x'] x = CustomDataWrapper(x) - batch['x'].to(device) + batch['x'] = x.to(device) return batch -.. note:: To decouple your data from transforms you can parametrize them via `__init__`. +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). + + +on_before_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch before it is transferred to the device. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + def on_before_batch_transfer(self, batch): + batch['x'] = transforms(batch['x']) + return batch + + +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). + + +on_after_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch after it is transferred to the device. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + def on_after_batch_transfer(self, batch): + batch['x'] = gpu_transforms(batch['x']) + return batch + + +.. note:: + This hook only runs on single GPU training and DDP (no data-parallel). This hook + will also be called when using CPU device, so adding augmentations here or in + ``on_before_batch_transfer`` means the same thing. + + + +.. note:: To decouple your data from transforms you can parametrize them via ``__init__``. .. code-block:: python diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2e8e31139dda2..dfc2aab625c09 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -105,7 +105,7 @@ def teardown(self): """ pass - def batch_to_device(self, batch: Any, device: torch.device) -> Any: + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just having all tensors on the correct device. @@ -114,8 +114,10 @@ def batch_to_device(self, batch: Any, device: torch.device) -> Any: device: The target device """ model = self.lightning_module + if model is not None: - return model.transfer_batch_to_device(batch, device) + return model._apply_batch_transfer_handler(batch, device) + return move_data_to_device(batch, device) def on_train_start(self): @@ -135,9 +137,7 @@ def training_step(self, args): :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. 
""" - batch = self.to_device(args[0]) - - args[0] = batch + args[0] = self.to_device(args[0]) with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context(): return self.training_type_plugin.training_step(*args) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 3b9d8e7de49e1..9ce05ce0b36cf 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -19,7 +19,6 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union -import torch from torch.utils.data import DataLoader, Dataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks @@ -95,7 +94,7 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn -class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper): +class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. @@ -248,26 +247,6 @@ def prepare_data(self, *args, **kwargs): def setup(self, stage: Optional[str] = None): pass - @abstractmethod - def train_dataloader(self, *args, **kwargs) -> DataLoader: - pass - - @abstractmethod - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - pass - - @abstractmethod - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - pass - - @abstractmethod - def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - pass - - @abstractmethod - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: - pass - @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: r"""Extends existing argparse by default `LightningDataModule` attributes.""" diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index e67b7c230e93c..895f004baa950 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -57,8 +57,7 @@ def auto_transfer_args(self, *args, **kwargs): if not isinstance(self, LightningModule): return fn(self, *args, **kwargs) - args = self.transfer_batch_to_device(args, self.device) - kwargs = self.transfer_batch_to_device(kwargs, self.device) + args, kwargs = self.transfer_batch_to_device((args, kwargs)) return fn(self, *args, **kwargs) return auto_transfer_args diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index ac7bb2a1d20e1..57ed5762528df 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -320,7 +320,7 @@ def on_after_backward(self): class DataHooks: - """Hooks to be used with LightningDataModule.""" + """Hooks to be used for data related stuff.""" def prepare_data(self) -> None: """ @@ -564,9 +564,27 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). + Note: + This hook should only transfer the data and not modify it, nor should it move the data to + any other device than the one passed in as argument (unless you know what you are doing). + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). 
+ If you need multi-GPU support for your custom batch objects, you need to define your custom + :class:`~torch.nn.parallel.DistributedDataParallel` or + :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. + + Args: + batch: A batch of data that needs to be transferred to a new device. + device: The target device as defined in PyTorch. + + Returns: + A reference to the data on the new device. + Example:: - def transfer_batch_to_device(self, batch, device) + def transfer_batch_to_device(self, batch, device): if isinstance(batch, CustomBatch): # move all tensors in your custom data structure to the device batch.samples = batch.samples.to(device) @@ -575,29 +593,62 @@ def transfer_batch_to_device(self, batch, device) batch = super().transfer_batch_to_device(data, device) return batch + See Also: + - :meth:`move_data_to_device` + - :meth:`apply_to_collection` + """ + device = device or self.device + return move_data_to_device(batch, device) + + def on_before_batch_transfer(self, batch): + """ + Override to alter or apply batch augmentations to your batch before it is transferred to the device. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Args: - batch: A batch of data that needs to be transferred to a new device. - device: The target device as defined in PyTorch. + batch: A batch of data that needs to be altered or augmented. Returns: - A reference to the data on the new device. + A batch of data - Note: - This hook should only transfer the data and not modify it, nor should it move the data to - any other device than the one passed in as argument (unless you know what you are doing). + Example:: + + def on_before_batch_transfer(self, batch): + batch['x'] = transforms(batch['x']) + return batch + + See Also: + - :meth:`on_after_batch_transfer` + - :meth:`transfer_batch_to_device` + """ + return batch + + def on_after_batch_transfer(self, batch): + """ + Override to alter or apply batch augmentations to your batch after it is transferred to the device. Note: This hook only runs on single GPU training and DDP (no data-parallel). - If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``, - you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or - override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. + + Args: + batch: A batch of data that needs to be altered or augmented. + + Returns: + A batch of data + + Example:: + + def on_after_batch_transfer(self, batch): + batch['x'] = gpu_transforms(batch['x']) + return batch See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` + - :meth:`on_before_batch_transfer` + - :meth:`transfer_batch_to_device` """ - device = device or self.device - return move_data_to_device(batch, device) + return batch class CheckpointHooks: @@ -611,7 +662,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: Args: checkpoint: Loaded checkpoint - Example:: def on_load_checkpoint(self, checkpoint): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 58d045044d0b4..91d2b3565d193 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -179,6 +179,12 @@ def logger(self): """ Reference to the logger object in the Trainer. 
""" return self.trainer.logger if self.trainer else None + def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None): + batch = self.on_before_batch_transfer(batch) + batch = self.transfer_batch_to_device(batch, device) + batch = self.on_after_batch_transfer(batch) + return batch + def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once. @@ -1697,7 +1703,7 @@ def to_onnx( ) input_sample = self.example_input_array - input_sample = self.transfer_batch_to_device(input_sample) + input_sample = self._apply_batch_transfer_handler(input_sample) if "example_outputs" not in kwargs: self.eval() @@ -1768,18 +1774,15 @@ def to_torchscript( if self.example_input_array is None: raise ValueError( 'Choosing method=`trace` requires either `example_inputs`' - ' or `model.example_input_array` to be defined' + ' or `model.example_input_array` to be defined.' ) example_inputs = self.example_input_array # automatically send example inputs to the right device and use trace - example_inputs = self.transfer_batch_to_device(example_inputs) + example_inputs = self._apply_batch_transfer_handler(example_inputs) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: - raise ValueError( - "The 'method' parameter only supports 'script' or 'trace'," - f" but value given was: {method}" - ) + raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}") self.train(mode) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index ce90e21e3528c..afb64535d1470 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -246,7 +246,7 @@ def _forward_example_input(self) -> None: trainer = self._model.trainer input_ = model.example_input_array - input_ = model.transfer_batch_to_device(input_, model.device) + input_ = model._apply_batch_transfer_handler(input_, model.device) if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) @@ -432,5 +432,5 @@ def get_human_readable_count(number: int) -> str: index = num_groups - 1 if index < 1 or number >= 100: return f"{int(number):,d} {labels[index]}" - else: - return f"{number:,.1f} {labels[index]}" + + return f"{number:,.1f} {labels[index]}" diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 2f8c888eba963..0485868fa2ef1 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -211,7 +211,7 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - input_array = model.transfer_batch_to_device(input_array, model.device) + input_array = model._apply_batch_transfer_handler(input_array) self.experiment.add_graph(model, input_array) else: rank_zero_warn( diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index e956172ba55c1..5734f0fd8aebc 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -155,14 +155,12 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - self.experiment.add_graph( - model, model.transfer_batch_to_device(model.example_input_array, model.device) - ) + 
self.experiment.add_graph(model, model._apply_batch_transfer_handler(input_array)) else: rank_zero_warn( - 'Could not log computational graph since the' - ' `model.example_input_array` attribute is not set' - ' or `input_array` was not given', UserWarning + 'Could not log computational graph since neither the' + ' `model.example_input_array` attribute is set nor' + ' `input_array` was given', UserWarning ) @rank_zero_only @@ -197,15 +195,15 @@ def save_dir(self) -> Optional[str]: def name(self) -> str: if self._experiment is None: return self._name - else: - return self.experiment.name + + return self.experiment.name @property def version(self) -> int: if self._experiment is None: return self._version - else: - return self.experiment.version + + return self.experiment.version # Test tube experiments are not pickleable, so we need to override a few # methods to get DDP working. See diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 3c2f02013921b..6ff35aadc36a3 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -131,9 +131,13 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st if is_overridden('predict_dataloader', datamodule): model.predict_dataloader = datamodule.predict_dataloader - # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule + # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule + if is_overridden('on_before_batch_transfer', datamodule): + model.on_before_batch_transfer = datamodule.on_before_batch_transfer if is_overridden('transfer_batch_to_device', datamodule): model.transfer_batch_to_device = datamodule.transfer_batch_to_device + if is_overridden('on_after_batch_transfer', datamodule): + model.on_after_batch_transfer = datamodule.on_after_batch_transfer self.trainer.datamodule = datamodule datamodule.trainer = self.trainer diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 8cf1f0a9d1ffb..e2f3b559073d2 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -423,7 +423,8 @@ def test_step_end(self, outputs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) -def test_dm_transfer_batch_to_device(get_module_mock): +def test_dm_apply_batch_transfer_handler(get_module_mock): + expected_device = torch.device('cuda', 0) class CustomBatch: @@ -432,14 +433,30 @@ def __init__(self, data): self.targets = data[1] class CurrentTestDM(LightningDataModule): - - hook_called = False - - def transfer_batch_to_device(self, data, device): - self.hook_called = True - data.samples = data.samples.to(device) - data.targets = data.targets.to(device) - return data + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None + + def on_before_batch_transfer(self, batch): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.samples += 1 + return batch + + def on_after_batch_transfer(self, batch): + assert batch.samples.device == batch.targets.device == expected_device + self.on_after_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.targets *= 2 + return batch + + def 
transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch dm = CurrentTestDM() model = BoringModel() @@ -452,10 +469,18 @@ def transfer_batch_to_device(self, data, device): if is_overridden('transfer_batch_to_device', dm): model.transfer_batch_to_device = dm.transfer_batch_to_device - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected = torch.device('cuda', 0) - assert dm.hook_called - assert batch_gpu.samples.device == batch_gpu.targets.device == expected + model.on_before_batch_transfer = dm.on_before_batch_transfer + model.transfer_batch_to_device = dm.transfer_batch_to_device + model.on_after_batch_transfer = dm.on_after_batch_transfer + + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device) + + assert dm.on_before_batch_transfer_hook_rank == 0 + assert dm.transfer_batch_to_device_hook_rank == 1 + assert dm.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) def test_dm_reload_dataloaders_every_epoch(tmpdir): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 057512be31af2..b258fb17e9015 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -145,7 +145,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) -def test_transfer_batch_hook(model_getter_mock): +def test_apply_batch_transfer_handler(model_getter_mock): + expected_device = torch.device('cuda', 0) class CustomBatch: @@ -154,28 +155,46 @@ def __init__(self, data): self.targets = data[1] class CurrentTestModel(BoringModel): - - hook_called = False - - def transfer_batch_to_device(self, data, device): - self.hook_called = True - if isinstance(data, CustomBatch): - data.samples = data.samples.to(device) - data.targets = data.targets.to(device) - else: - data = super().transfer_batch_to_device(data, device) - return data + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None + + def on_before_batch_transfer(self, batch): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.samples += 1 + return batch + + def on_after_batch_transfer(self, batch): + assert batch.samples.device == batch.targets.device == expected_device + self.on_after_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.targets *= 2 + return batch + + def transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch model = CurrentTestModel() batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer(gpus=1) # running .fit() would require us to implement custom data loaders, we mock the model reference instead + model_getter_mock.return_value = model - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - 
expected = torch.device('cuda', 0) - assert model.hook_called - assert batch_gpu.samples.device == batch_gpu.targets.device == expected + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device) + + assert model.on_before_batch_transfer_hook_rank == 0 + assert model.transfer_batch_to_device_hook_rank == 1 + assert model.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
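
Editor's note: taken together, the changes above route every batch through three hooks in a fixed order. Below is a minimal, CPU-runnable sketch of that ordering; the ``AugmentedModule`` class and the tensor arithmetic are illustrative assumptions, not part of this patch. The three manual calls mirror what the new ``_apply_batch_transfer_handler`` helper performs on the model's behalf when the accelerator moves a batch.

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class AugmentedModule(LightningModule):
        """Hypothetical module overriding the new data hooks (values chosen only to make the ordering visible)."""

        def on_before_batch_transfer(self, batch):
            # Runs before the batch is moved, i.e. still on the CPU.
            x, y = batch
            return x + 1, y

        def on_after_batch_transfer(self, batch):
            # Runs after the move; the batch already lives on the target device.
            x, y = batch
            return x * 2, y


    model = AugmentedModule()
    batch = (torch.zeros(4, 32), torch.ones(4, dtype=torch.long))

    # Manual equivalent of what `_apply_batch_transfer_handler(batch, device)` does:
    batch = model.on_before_batch_transfer(batch)
    batch = model.transfer_batch_to_device(batch, torch.device("cpu"))
    batch = model.on_after_batch_transfer(batch)

    assert torch.equal(batch[0], torch.full((4, 32), 2.0))

As the notes added in this diff state, these hooks do not run when the model is wrapped for data-parallel execution (``dp``/``ddp2``), so the sketch assumes single-device or plain DDP training.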
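
Editor's note: the ``datamodules.rst`` snippets above rely on undefined ``transforms`` and ``gpu_transforms`` callables. The following is a more self-contained sketch of the intended split between CPU-side preprocessing and device-side augmentation; the class name, normalisation constants, and noise augmentation are illustrative assumptions rather than anything prescribed by this patch.

.. code-block:: python

    import torch
    from pytorch_lightning import LightningDataModule


    class ImageDataModule(LightningDataModule):
        """Hypothetical datamodule splitting CPU-side and device-side transforms."""

        def on_before_batch_transfer(self, batch):
            # Before the transfer the tensors are still on the CPU, so keep this
            # to cheap, elementwise work such as normalisation.
            batch["x"] = (batch["x"] - 0.5) / 0.5
            return batch

        def on_after_batch_transfer(self, batch):
            # After the transfer the tensors sit on the training device, so
            # batched augmentations such as additive noise can use the GPU.
            batch["x"] = batch["x"] + 0.01 * torch.randn_like(batch["x"])
            return batch


    # Standalone call, outside a Trainer, just to show the hooks are plain methods:
    dm = ImageDataModule()
    cpu_batch = dm.on_before_batch_transfer({"x": torch.rand(8, 3, 32, 32)})

When such a datamodule is passed to the Trainer, the data connector change in this diff copies each overridden hook onto the LightningModule, so the accelerator applies it exactly as it would a model-level override.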