From 3aef3471143ebcdeb9adc5a9e38354e45fdd70e7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 19:38:12 +0530 Subject: [PATCH 01/32] add hooks --- pytorch_lightning/core/datamodule.py | 22 +++++-------------- pytorch_lightning/core/decorators.py | 4 ++-- pytorch_lightning/core/hooks.py | 14 +++++++++++- pytorch_lightning/core/lightning.py | 8 +++---- pytorch_lightning/core/memory.py | 2 +- pytorch_lightning/loggers/tensorboard.py | 2 +- pytorch_lightning/loggers/test_tube.py | 8 +++---- .../trainer/connectors/data_connector.py | 4 ++++ 8 files changed, 34 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index d0e1725b2c4ac..8f1273bf4d6b3 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -95,7 +95,11 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn -class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper): +class LightningDataModule( + CheckpointHooks, + DataHooks, + metaclass=_DataModuleWrapper +): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. @@ -248,22 +252,6 @@ def prepare_data(self, *args, **kwargs): def setup(self, stage: Optional[str] = None): pass - @abstractmethod - def train_dataloader(self, *args, **kwargs) -> DataLoader: - pass - - @abstractmethod - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - pass - - @abstractmethod - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - pass - - @abstractmethod - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: - pass - @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: r"""Extends existing argparse by default `LightningDataModule` attributes.""" diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index e67b7c230e93c..36f3cacc162b0 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -57,8 +57,8 @@ def auto_transfer_args(self, *args, **kwargs): if not isinstance(self, LightningModule): return fn(self, *args, **kwargs) - args = self.transfer_batch_to_device(args, self.device) - kwargs = self.transfer_batch_to_device(kwargs, self.device) + args = self.transfer_batch_to_device(args) + kwargs = self.transfer_batch_to_device(kwargs) return fn(self, *args, **kwargs) return auto_transfer_args diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 23fd5d9b58755..a823bc831f660 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -314,7 +314,7 @@ def on_after_backward(self): class DataHooks: - """Hooks to be used with LightningDataModule.""" + """Hooks to be used for data related stuff.""" def prepare_data(self) -> None: """ @@ -568,6 +568,18 @@ def transfer_batch_to_device(self, batch, device) device = device or self.device return move_data_to_device(batch, device) + def on_before_batch_transfer(self, batch): + return batch + + def on_after_batch_transfer(self, batch): + return batch + + def prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None): + batch = self.on_before_batch_transfer(batch) + batch = self.transfer_batch_to_device(batch, device) + batch = self.on_after_batch_transfer(batch) + return batch + class CheckpointHooks: """Hooks to be 
used with Checkpointing.""" diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9c87836b4415a..c81d7c3247f56 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -48,12 +48,12 @@ class LightningModule( ABC, + CheckpointHooks, + DataHooks, DeviceDtypeModuleMixin, GradInformation, - ModelIO, ModelHooks, - DataHooks, - CheckpointHooks, + ModelIO, Module, ): # Below is for property support of JIT in PyTorch 1.7 @@ -1773,7 +1773,7 @@ def to_torchscript( example_inputs = self.example_input_array # automatically send example inputs to the right device and use trace - example_inputs = self.transfer_batch_to_device(example_inputs) + example_inputs = self.prepare_batch_for_transfer(example_inputs) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: raise ValueError( diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index ce90e21e3528c..e0008a8fa60ab 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -246,7 +246,7 @@ def _forward_example_input(self) -> None: trainer = self._model.trainer input_ = model.example_input_array - input_ = model.transfer_batch_to_device(input_, model.device) + input_ = model.prepare_batch_for_transfer(input_, model.device) if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 2f8c888eba963..2ef8e41f545e7 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -211,7 +211,7 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - input_array = model.transfer_batch_to_device(input_array, model.device) + input_array = model.prepare_batch_for_transfer(input_array) self.experiment.add_graph(model, input_array) else: rank_zero_warn( diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index e956172ba55c1..45cbbd59b32f0 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -156,13 +156,13 @@ def log_graph(self, model: LightningModule, input_array=None): if input_array is not None: self.experiment.add_graph( - model, model.transfer_batch_to_device(model.example_input_array, model.device) + model, model._apply_batch_transfer_handler(input_array) ) else: rank_zero_warn( - 'Could not log computational graph since the' - ' `model.example_input_array` attribute is not set' - ' or `input_array` was not given', UserWarning + 'Could not log computational graph since neither the' + ' `model.example_input_array` attribute is set nor' + ' `input_array` was given', UserWarning ) @rank_zero_only diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9161f3e8754ec..ab57f162029a6 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -120,8 +120,12 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st model.test_dataloader = datamodule.test_dataloader # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule + if is_overridden('on_before_batch_transfer', 
datamodule): + model.on_before_batch_transfer = datamodule.on_before_batch_transfer if is_overridden('transfer_batch_to_device', datamodule): model.transfer_batch_to_device = datamodule.transfer_batch_to_device + if is_overridden('on_after_batch_transfer', datamodule): + model.on_after_batch_transfer = datamodule.on_after_batch_transfer self.trainer.datamodule = datamodule datamodule.trainer = self.trainer From d0e949c3b1e64d0f9dccb182b1633fe172d3baf2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 19:54:22 +0530 Subject: [PATCH 02/32] comment --- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index ab57f162029a6..5c69defe21f9c 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -119,7 +119,7 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st if is_overridden('test_dataloader', datamodule): model.test_dataloader = datamodule.test_dataloader - # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule + # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule if is_overridden('on_before_batch_transfer', datamodule): model.on_before_batch_transfer = datamodule.on_before_batch_transfer if is_overridden('transfer_batch_to_device', datamodule): From 6b98fdd43b6692a25c638935fd543681de1900f0 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 20:57:25 +0530 Subject: [PATCH 03/32] docs --- pytorch_lightning/core/hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a823bc831f660..c2359291f8e87 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -535,7 +535,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = Example:: - def transfer_batch_to_device(self, batch, device) + def transfer_batch_to_device(self, batch, device): if isinstance(batch, CustomBatch): # move all tensors in your custom data structure to the device batch.samples = batch.samples.to(device) From 48e724c004298766f2bb24d776c58ad4acff88eb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 22:06:49 +0530 Subject: [PATCH 04/32] add tests --- tests/core/test_datamodules.py | 63 ++++++++++++++++++++++++++++------ tests/models/test_hooks.py | 57 +++++++++++++++++++++++------- 2 files changed, 97 insertions(+), 23 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 8cf1f0a9d1ffb..45676b5811cd3 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -423,7 +423,7 @@ def test_step_end(self, outputs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) -def test_dm_transfer_batch_to_device(get_module_mock): +def test_dm_apply_batch_transfer_handler(get_module_mock): class CustomBatch: @@ -432,14 +432,44 @@ def __init__(self, data): self.targets = data[1] class CurrentTestDM(LightningDataModule): + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None - hook_called = False 
+ def on_before_batch_transfer(self, batch): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 - def transfer_batch_to_device(self, data, device): - self.hook_called = True - data.samples = data.samples.to(device) - data.targets = data.targets.to(device) - return data + if isinstance(batch, CustomBatch): + batch.samples += 1 + else: + batch = super().on_before_batch_transfer(batch) + + return batch + + def on_after_batch_transfer(self, batch): + self.on_after_batch_transfer_hook_rank = self.rank + self.rank += 1 + + if isinstance(batch, CustomBatch): + batch.targets *= 2 + else: + batch = super().on_after_batch_transfer(batch) + + return batch + + def transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + + if isinstance(batch, CustomBatch): + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + else: + batch = super().transfer_batch_to_device(batch, device) + + return batch dm = CurrentTestDM() model = BoringModel() @@ -452,10 +482,23 @@ def transfer_batch_to_device(self, data, device): if is_overridden('transfer_batch_to_device', dm): model.transfer_batch_to_device = dm.transfer_batch_to_device + if is_overridden('on_before_batch_transfer', dm): + model.on_before_batch_transfer = dm.on_before_batch_transfer + if is_overridden('transfer_batch_to_device', dm): + model.transfer_batch_to_device = dm.transfer_batch_to_device + if is_overridden('on_after_batch_transfer', dm): + model.on_after_batch_transfer = dm.on_after_batch_transfer + + trainer.accelerator_backend = GPUAccelerator(trainer) batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected = torch.device('cuda', 0) - assert dm.hook_called - assert batch_gpu.samples.device == batch_gpu.targets.device == expected + expected_device = torch.device('cuda', 0) + + assert dm.on_before_batch_transfer_hook_rank == 0 + assert dm.transfer_batch_to_device_hook_rank == 1 + assert dm.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1) * 2) def test_dm_reload_dataloaders_every_epoch(tmpdir): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 057512be31af2..27687aa958806 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -145,26 +145,52 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) -def test_transfer_batch_hook(model_getter_mock): +def test_apply_batch_transfer_handler(model_getter_mock): class CustomBatch: - def __init__(self, data): self.samples = data[0] self.targets = data[1] class CurrentTestModel(BoringModel): + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None + + def on_before_batch_transfer(self, batch): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 + + if isinstance(batch, CustomBatch): + batch.samples += 1 + else: + batch = super().on_before_batch_transfer(batch) + + return batch - hook_called = False + def on_after_batch_transfer(self, batch): + self.on_after_batch_transfer_hook_rank = self.rank 
+ self.rank += 1 - def transfer_batch_to_device(self, data, device): - self.hook_called = True - if isinstance(data, CustomBatch): - data.samples = data.samples.to(device) - data.targets = data.targets.to(device) + if isinstance(batch, CustomBatch): + batch.targets *= 2 else: - data = super().transfer_batch_to_device(data, device) - return data + batch = super().on_after_batch_transfer(batch) + + return batch + + def transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + + if isinstance(batch, CustomBatch): + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + else: + batch = super().transfer_batch_to_device(batch, device) + + return batch model = CurrentTestModel() batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) @@ -173,9 +199,14 @@ def transfer_batch_to_device(self, data, device): # running .fit() would require us to implement custom data loaders, we mock the model reference instead model_getter_mock.return_value = model batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected = torch.device('cuda', 0) - assert model.hook_called - assert batch_gpu.samples.device == batch_gpu.targets.device == expected + expected_device = torch.device('cuda', 0) + + assert model.on_before_batch_transfer_hook_rank == 0 + assert model.transfer_batch_to_device_hook_rank == 1 + assert model.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1) * 2) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From 28644298638c6a2f51a66287de99e6214400762d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 22:07:03 +0530 Subject: [PATCH 05/32] make it private --- pytorch_lightning/core/hooks.py | 12 ++++++------ pytorch_lightning/core/lightning.py | 22 ++++++++++++++++++++++ pytorch_lightning/core/memory.py | 2 +- pytorch_lightning/loggers/tensorboard.py | 2 +- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index c2359291f8e87..e78dc94089be6 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -569,15 +569,15 @@ def transfer_batch_to_device(self, batch, device): return move_data_to_device(batch, device) def on_before_batch_transfer(self, batch): + """ + Called before batch is transfered to the device + """ return batch def on_after_batch_transfer(self, batch): - return batch - - def prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None): - batch = self.on_before_batch_transfer(batch) - batch = self.transfer_batch_to_device(batch, device) - batch = self.on_after_batch_transfer(batch) + """ + Called after batch is transfered to the device + """ return batch diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c81d7c3247f56..f09808bba698a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -179,6 +179,12 @@ def logger(self): """ Reference to the logger object in the Trainer. 
""" return self.trainer.logger if self.trainer else None + def _prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None): + batch = self.on_before_batch_transfer(batch) + batch = self.transfer_batch_to_device(batch, device) + batch = self.on_after_batch_transfer(batch) + return batch + def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once. @@ -1760,6 +1766,7 @@ def to_torchscript( """ mode = self.training +<<<<<<< HEAD if method == 'script': torchscript_module = torch.jit.script(self.eval(), **kwargs) elif method == 'trace': @@ -1781,6 +1788,21 @@ def to_torchscript( f" but value given was: {method}" ) +======= + with torch.no_grad(): + if method == 'script': + torchscript_module = torch.jit.script(self.eval(), **kwargs) + elif method == 'trace': + # if no example inputs are provided, try to see if model has example_input_array set + if example_inputs is None: + example_inputs = self.example_input_array + # automatically send example inputs to the right device and use trace + example_inputs = self._prepare_batch_for_transfer(example_inputs) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) + else: + raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" + f"{method}") +>>>>>>> make it private self.train(mode) if file_path is not None: diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index e0008a8fa60ab..fdb070f0f0348 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -246,7 +246,7 @@ def _forward_example_input(self) -> None: trainer = self._model.trainer input_ = model.example_input_array - input_ = model.prepare_batch_for_transfer(input_, model.device) + input_ = model._prepare_batch_for_transfer(input_, model.device) if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 2ef8e41f545e7..69a0d2c28e707 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -211,7 +211,7 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - input_array = model.prepare_batch_for_transfer(input_array) + input_array = model._prepare_batch_for_transfer(input_array) self.experiment.add_graph(model, input_array) else: rank_zero_warn( From bbb6bb62cb7d3e5068a52a4405569519f310ae88 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 22:50:43 +0530 Subject: [PATCH 06/32] fix tests --- tests/core/test_datamodules.py | 2 +- tests/models/test_hooks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 45676b5811cd3..d7ad389bc74d7 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -498,7 +498,7 @@ def transfer_batch_to_device(self, batch, device): assert dm.on_after_batch_transfer_hook_rank == 2 assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) - assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1) * 2) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) def 
test_dm_reload_dataloaders_every_epoch(tmpdir): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 27687aa958806..6e6a18ad24c88 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -206,7 +206,7 @@ def transfer_batch_to_device(self, batch, device): assert model.on_after_batch_transfer_hook_rank == 2 assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) - assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1) * 2) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From 5e7ef6b8906ee40e2b750b10f76a89ef9e9aff7a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 22 Nov 2020 00:17:35 +0530 Subject: [PATCH 07/32] docs --- docs/source/common/lightning_module.rst | 12 ++++++++++ docs/source/extensions/datamodules.rst | 30 ++++++++++++++++++++++++- pytorch_lightning/core/hooks.py | 4 ++-- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 943525902f41b..43e58662c3b5b 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1209,3 +1209,15 @@ transfer_batch_to_device .. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device :noindex: + +on_before_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer + :noindex: + +on_after_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer + :noindex: diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 443cd5be4204b..51f34ff038bbf 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -286,7 +286,7 @@ Use this method to generate the test dataloader. Usually you just wrap the datas transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ -Override to define how you want to move an arbitrary batch to a device +Override to define how you want to move an arbitrary batch to a device. .. code-block:: python @@ -300,6 +300,34 @@ Override to define how you want to move an arbitrary batch to a device batch['x'].to(device) return batch +on_before_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply batch augmentations to your batch before it is transferred to the device. + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def on_before_batch_transfer(self, batch): + batch['x'] = transforms(batch['x']) + return batch + +on_after_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply batch augmentations to your batch after it is transferred to the device. + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def on_after_batch_transfer(self, batch): + batch['x'] = gpu_transforms(batch['x']) + return batch + .. note:: To decouple your data from transforms you can parametrize them via `__init__`. 
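For orientation, the two hooks documented above and the existing ``transfer_batch_to_device`` form a single per-batch pipeline: before-transfer work on CPU, the device move itself, then after-transfer work on the target device. The sketch below is illustrative only and not part of the diff; the module name and the toy augmentation/normalization are assumptions, and it presumes the hook signatures as they stand in this patch series:

    import torch
    import pytorch_lightning as pl


    class AugmentedDataModule(pl.LightningDataModule):  # illustrative name, not part of the PR
        def on_before_batch_transfer(self, batch):
            # still on CPU: cheap per-batch augmentation before the device move
            x, y = batch
            return x + 0.01 * torch.randn_like(x), y

        def transfer_batch_to_device(self, batch, device):
            # default behaviour: recursively move all tensors in the batch to `device`
            return super().transfer_batch_to_device(batch, device)

        def on_after_batch_transfer(self, batch):
            # already on the target device: e.g. normalization that benefits from the GPU
            x, y = batch
            return (x - x.mean()) / (x.std() + 1e-6), y

Each hook returns the (possibly new) batch object, and that return value is what the trainer hands to the next stage of the pipeline.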
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e78dc94089be6..b619dbc2fabb7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -570,13 +570,13 @@ def transfer_batch_to_device(self, batch, device): def on_before_batch_transfer(self, batch): """ - Called before batch is transfered to the device + Called before batch is transferred to the device """ return batch def on_after_batch_transfer(self, batch): """ - Called after batch is transfered to the device + Called after batch is transferred to the device """ return batch From 4fde09245d1b51bccce5e6a9e9eaeb010a9125e6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 22 Nov 2020 00:19:33 +0530 Subject: [PATCH 08/32] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ef3267812a80..e06ce60cd3861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579)) +- Added `on_before_batch_transfer` and `on_after_batch_transfer` data hooks ([#3671](https://github.com/PyTorchLightning/pytorch-lightning/pull/3671)) + + - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479)) From f4f5a27d5fbabba6fe62e1dfe0d656906c16ad9c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 1 Dec 2020 02:17:50 +0530 Subject: [PATCH 09/32] testcode --- docs/source/extensions/datamodules.rst | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 51f34ff038bbf..472922bab2920 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -288,10 +288,7 @@ transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ Override to define how you want to move an arbitrary batch to a device. -.. code-block:: python - - import pytorch_lightning as pl - +.. testcode:: class MNISTDataModule(pl.LightningDataModule): def transfer_batch_to_device(self, batch, device): @@ -304,10 +301,7 @@ on_before_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch before it is transferred to the device. -.. code-block:: python - - import pytorch_lightning as pl - +.. testcode:: class MNISTDataModule(pl.LightningDataModule): def on_before_batch_transfer(self, batch): @@ -318,10 +312,7 @@ on_after_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch after it is transferred to the device. -.. code-block:: python - - import pytorch_lightning as pl - +.. 
testcode:: class MNISTDataModule(pl.LightningDataModule): def on_after_batch_transfer(self, batch): From 4df6f0854334de719f30e31a671332226c453e3a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 3 Dec 2020 23:27:01 +0530 Subject: [PATCH 10/32] codefactor --- pytorch_lightning/core/memory.py | 4 ++-- pytorch_lightning/loggers/test_tube.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index fdb070f0f0348..a663fe43b976f 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -432,5 +432,5 @@ def get_human_readable_count(number: int) -> str: index = num_groups - 1 if index < 1 or number >= 100: return f"{int(number):,d} {labels[index]}" - else: - return f"{number:,.1f} {labels[index]}" + + return f"{number:,.1f} {labels[index]}" diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 45cbbd59b32f0..19484d96fcef2 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -197,15 +197,15 @@ def save_dir(self) -> Optional[str]: def name(self) -> str: if self._experiment is None: return self._name - else: - return self.experiment.name + + return self.experiment.name @property def version(self) -> int: if self._experiment is None: return self._version - else: - return self.experiment.version + + return self.experiment.version # Test tube experiments are not pickleable, so we need to override a few # methods to get DDP working. See From 7252acf76c5e09480ab0d7180488e117075e6970 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 12 Dec 2020 01:31:21 +0530 Subject: [PATCH 11/32] fix doctest --- docs/source/extensions/datamodules.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 472922bab2920..b10b806ba86c1 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -290,6 +290,9 @@ Override to define how you want to move an arbitrary batch to a device. .. testcode:: + import pytorch_lightning as pl + + class MNISTDataModule(pl.LightningDataModule): def transfer_batch_to_device(self, batch, device): x = batch['x'] @@ -303,6 +306,9 @@ Override to alter or apply batch augmentations to your batch before it is transf .. testcode:: + import pytorch_lightning as pl + + class MNISTDataModule(pl.LightningDataModule): def on_before_batch_transfer(self, batch): batch['x'] = transforms(batch['x']) @@ -314,6 +320,9 @@ Override to alter or apply batch augmentations to your batch after it is transfe .. testcode:: + import pytorch_lightning as pl + + class MNISTDataModule(pl.LightningDataModule): def on_after_batch_transfer(self, batch): batch['x'] = gpu_transforms(batch['x']) From e7dafd077158fc1655b126bd5218b47c7317e225 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 12 Dec 2020 01:46:35 +0530 Subject: [PATCH 12/32] fix doctest --- docs/source/extensions/datamodules.rst | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index b10b806ba86c1..cfd8f3024f7f0 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -1,3 +1,7 @@ +.. testsetup:: * + + from pytorch_lightning.core.datamodule import LightningDataModule + .. 
_datamodules: LightningDataModule @@ -290,10 +294,7 @@ Override to define how you want to move an arbitrary batch to a device. .. testcode:: - import pytorch_lightning as pl - - - class MNISTDataModule(pl.LightningDataModule): + class MNISTDataModule(LightningDataModule): def transfer_batch_to_device(self, batch, device): x = batch['x'] x = CustomDataWrapper(x) @@ -306,10 +307,7 @@ Override to alter or apply batch augmentations to your batch before it is transf .. testcode:: - import pytorch_lightning as pl - - - class MNISTDataModule(pl.LightningDataModule): + class MNISTDataModule(LightningDataModule): def on_before_batch_transfer(self, batch): batch['x'] = transforms(batch['x']) return batch @@ -320,10 +318,7 @@ Override to alter or apply batch augmentations to your batch after it is transfe .. testcode:: - import pytorch_lightning as pl - - - class MNISTDataModule(pl.LightningDataModule): + class MNISTDataModule(LightningDataModule): def on_after_batch_transfer(self, batch): batch['x'] = gpu_transforms(batch['x']) return batch From d9fc2afbd4f8fc985d7b172f30a5e18a8c15bd42 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 12 Dec 2020 03:52:11 +0530 Subject: [PATCH 13/32] suggestions --- docs/source/extensions/datamodules.rst | 6 +----- tests/core/test_datamodules.py | 24 +++++------------------- tests/models/test_hooks.py | 23 ++++------------------- 3 files changed, 10 insertions(+), 43 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index cfd8f3024f7f0..13160c3b425ad 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -1,7 +1,3 @@ -.. testsetup:: * - - from pytorch_lightning.core.datamodule import LightningDataModule - .. _datamodules: LightningDataModule @@ -298,7 +294,7 @@ Override to define how you want to move an arbitrary batch to a device. 
def transfer_batch_to_device(self, batch, device): x = batch['x'] x = CustomDataWrapper(x) - batch['x'].to(device) + x.to(device) return batch on_before_batch_transfer diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index d7ad389bc74d7..24948d41403ca 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -23,6 +23,7 @@ from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel @@ -440,35 +441,20 @@ class CurrentTestDM(LightningDataModule): def on_before_batch_transfer(self, batch): self.on_before_batch_transfer_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.samples += 1 - else: - batch = super().on_before_batch_transfer(batch) - + batch.samples += 1 return batch def on_after_batch_transfer(self, batch): self.on_after_batch_transfer_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.targets *= 2 - else: - batch = super().on_after_batch_transfer(batch) - + batch.targets *= 2 return batch def transfer_batch_to_device(self, batch, device): self.transfer_batch_to_device_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.samples = batch.samples.to(device) - batch.targets = batch.targets.to(device) - else: - batch = super().transfer_batch_to_device(batch, device) - + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) return batch dm = CurrentTestDM() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6e6a18ad24c88..6cc44ca28471d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -161,35 +161,20 @@ class CurrentTestModel(BoringModel): def on_before_batch_transfer(self, batch): self.on_before_batch_transfer_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.samples += 1 - else: - batch = super().on_before_batch_transfer(batch) - + batch.samples += 1 return batch def on_after_batch_transfer(self, batch): self.on_after_batch_transfer_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.targets *= 2 - else: - batch = super().on_after_batch_transfer(batch) - + batch.targets *= 2 return batch def transfer_batch_to_device(self, batch, device): self.transfer_batch_to_device_hook_rank = self.rank self.rank += 1 - - if isinstance(batch, CustomBatch): - batch.samples = batch.samples.to(device) - batch.targets = batch.targets.to(device) - else: - batch = super().transfer_batch_to_device(batch, device) - + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) return batch model = CurrentTestModel() From 69437cf2cdea05acb8cfe71b1bf2069ba5cbdb21 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 12 Dec 2020 04:07:17 +0530 Subject: [PATCH 14/32] is always overriden --- tests/core/test_datamodules.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 24948d41403ca..c34c034b9c15a 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -468,12 +468,9 @@ def transfer_batch_to_device(self, batch, device): if is_overridden('transfer_batch_to_device', dm): 
model.transfer_batch_to_device = dm.transfer_batch_to_device - if is_overridden('on_before_batch_transfer', dm): - model.on_before_batch_transfer = dm.on_before_batch_transfer - if is_overridden('transfer_batch_to_device', dm): - model.transfer_batch_to_device = dm.transfer_batch_to_device - if is_overridden('on_after_batch_transfer', dm): - model.on_after_batch_transfer = dm.on_after_batch_transfer + model.on_before_batch_transfer = dm.on_before_batch_transfer + model.transfer_batch_to_device = dm.transfer_batch_to_device + model.on_after_batch_transfer = dm.on_after_batch_transfer trainer.accelerator_backend = GPUAccelerator(trainer) batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) From 91be6ca44145e0825d45bc9c09eadfa25a052759 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 22 Dec 2020 01:21:09 +0530 Subject: [PATCH 15/32] pep and BoringModel --- tests/core/test_datamodules.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index c34c034b9c15a..733fd3e619915 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -23,7 +23,6 @@ from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel @@ -457,9 +456,14 @@ def transfer_batch_to_device(self, batch, device): batch.targets = batch.targets.to(device) return batch +<<<<<<< HEAD dm = CurrentTestDM() model = BoringModel() +======= + model = BoringModel() + dm = CurrentTestDM() +>>>>>>> pep and BoringModel batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer(gpus=1) From c6ca4ad3c112bc48f56b27b9ed8570e3bd88b65b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 22 Dec 2020 02:04:23 +0530 Subject: [PATCH 16/32] BoringModel --- tests/core/test_datamodules.py | 2 +- tests/models/test_hooks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 733fd3e619915..cb223ab0a79d9 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -484,7 +484,7 @@ def transfer_batch_to_device(self, batch, device): assert dm.transfer_batch_to_device_hook_rank == 1 assert dm.on_after_batch_transfer_hook_rank == 2 assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device - assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6cc44ca28471d..1abd96be72d2e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -190,7 +190,7 @@ def transfer_batch_to_device(self, batch, device): assert model.transfer_batch_to_device_hook_rank == 1 assert model.on_after_batch_transfer_hook_rank == 2 assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device - assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 28)) + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) From 
95599da0d612515cd10aa07172cb36354d5786d5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 29 Dec 2020 01:42:17 +0530 Subject: [PATCH 17/32] docs --- docs/source/extensions/datamodules.rst | 8 ++- pytorch_lightning/core/hooks.py | 76 +++++++++++++++++++------- 2 files changed, 64 insertions(+), 20 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 13160c3b425ad..d30fb139daa8b 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -288,6 +288,8 @@ transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ Override to define how you want to move an arbitrary batch to a device. +.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + .. testcode:: class MNISTDataModule(LightningDataModule): @@ -301,6 +303,8 @@ on_before_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch before it is transferred to the device. +.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + .. testcode:: class MNISTDataModule(LightningDataModule): @@ -312,6 +316,8 @@ on_after_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch after it is transferred to the device. +.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + .. testcode:: class MNISTDataModule(LightningDataModule): @@ -320,7 +326,7 @@ Override to alter or apply batch augmentations to your batch after it is transfe return batch -.. note:: To decouple your data from transforms you can parametrize them via `__init__`. +.. note:: To decouple your data from transforms you can parametrize them via ``__init__``. .. code-block:: python diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b619dbc2fabb7..ce54e1bcefeb7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -533,16 +533,16 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). - Example:: + Note: + This hook should only transfer the data and not modify it, nor should it move the data to + any other device than the one passed in as argument (unless you know what you are doing). - def transfer_batch_to_device(self, batch, device): - if isinstance(batch, CustomBatch): - # move all tensors in your custom data structure to the device - batch.samples = batch.samples.to(device) - batch.targets = batch.targets.to(device) - else: - batch = super().transfer_batch_to_device(data, device) - return batch + Note: + This hook only runs on single GPU training (no data-parallel) or with sharded plugin. + If you need multi-GPU support for your custom batch objects, you need to define your custom + :class:`~torch.nn.parallel.DistributedDataParallel` or + :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. Args: batch: A batch of data that needs to be transferred to a new device. @@ -551,15 +551,16 @@ def transfer_batch_to_device(self, batch, device): Returns: A reference to the data on the new device. - Note: - This hook should only transfer the data and not modify it, nor should it move the data to - any other device than the one passed in as argument (unless you know what you are doing). 
+ Example:: - Note: - This hook only runs on single GPU training and DDP (no data-parallel). - If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``, - you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or - override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. + def transfer_batch_to_device(self, batch, device): + if isinstance(batch, CustomBatch): + # move all tensors in your custom data structure to the device + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + else: + batch = super().transfer_batch_to_device(data, device) + return batch See Also: - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` @@ -570,13 +571,50 @@ def transfer_batch_to_device(self, batch, device): def on_before_batch_transfer(self, batch): """ - Called before batch is transferred to the device + Override to alter or apply batch augmentations to your batch before it is transferred to the device. + + Note: + This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + + Args: + batch: A batch of data that needs to be altered or augmented. + + Returns: + A batch of data + + Example:: + + def on_before_batch_transfer(self, batch): + batch['x'] = transforms(batch['x']) + return batch + + See Also: + - :func:`~pytorch_lightning.core.on_after_batch_transfer` + - :func:`~pytorch_lightning.core.transfer_batch_to_device` """ return batch def on_after_batch_transfer(self, batch): """ - Called after batch is transferred to the device + Override to alter or apply batch augmentations to your batch after it is transferred to the device. + + Note: + This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + + Args: + batch: A batch of data that needs to be altered or augmented. + + Returns: + A batch of data + + Example:: + def on_after_batch_transfer(self, batch): + batch['x'] = gpu_transforms(batch['x']) + return batch + + See Also: + - :func:`~pytorch_lightning.core.on_before_batch_transfer` + - :func:`~pytorch_lightning.core.transfer_batch_to_device` """ return batch From 501cb0b750bd00a0ce5e5f623badb67a68fd286b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 29 Dec 2020 01:48:33 +0530 Subject: [PATCH 18/32] docs --- pytorch_lightning/core/hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index ce54e1bcefeb7..bb76d2c310537 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -529,7 +529,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = - :class:`list` - :class:`dict` - :class:`tuple` - - :class:`torchtext.data.batch.Batch` + - :class:`~torchtext.data.batch.Batch` For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). 
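The surrounding docstrings lean on a user-defined ``CustomBatch`` container; spelled out as a self-contained sketch, the pattern looks as follows (``CustomBatch`` and ``CustomBatchModel`` are illustrative names, not Lightning APIs; only the override itself mirrors the docstring example):

    import torch
    from pytorch_lightning import LightningModule


    class CustomBatch:
        # arbitrary container; Lightning's default moving logic cannot look inside it
        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]


    class CustomBatchModel(LightningModule):
        def transfer_batch_to_device(self, batch, device):
            if isinstance(batch, CustomBatch):
                # move every tensor held by the container explicitly
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
            else:
                # fall back to the default recursive move for known collection types
                batch = super().transfer_batch_to_device(batch, device)
            return batch

Whatever is returned here is what ``on_after_batch_transfer`` subsequently receives.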
@@ -608,6 +608,7 @@ def on_after_batch_transfer(self, batch): A batch of data Example:: + def on_after_batch_transfer(self, batch): batch['x'] = gpu_transforms(batch['x']) return batch From 3b96ff78e04868d4983a0a7f4fafbfc67dfe3083 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 29 Dec 2020 02:00:28 +0530 Subject: [PATCH 19/32] docs --- pytorch_lightning/core/hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index bb76d2c310537..dbecacc09f247 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -529,7 +529,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = - :class:`list` - :class:`dict` - :class:`tuple` - - :class:`~torchtext.data.batch.Batch` + - :class:`torchtext.data.batch.Batch` For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). From 9fe7bef5d9de00eae6e3c5e4a782a1c551a16794 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 1 Jan 2021 20:35:37 +0530 Subject: [PATCH 20/32] fix --- docs/source/extensions/datamodules.rst | 2 -- pytorch_lightning/core/hooks.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index d30fb139daa8b..790d158c945d8 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -303,8 +303,6 @@ on_before_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch before it is transferred to the device. -.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. - .. testcode:: class MNISTDataModule(LightningDataModule): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index dbecacc09f247..a9df7b0dc8065 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -573,9 +573,6 @@ def on_before_batch_transfer(self, batch): """ Override to alter or apply batch augmentations to your batch before it is transferred to the device. - Note: - This hook only runs on single gpu training (no data-parallel) or with sharded plugin. - Args: batch: A batch of data that needs to be altered or augmented. From b32d6d64e3fb88373cb7e49e64afded9cbee258a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 27 Jan 2021 00:37:43 +0530 Subject: [PATCH 21/32] rebase --- docs/source/extensions/datamodules.rst | 6 +++-- pytorch_lightning/core/hooks.py | 7 +++-- pytorch_lightning/core/lightning.py | 36 ++++++++------------------ tests/core/test_datamodules.py | 5 ---- 4 files changed, 20 insertions(+), 34 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 790d158c945d8..a9e53b5fce34e 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -288,7 +288,7 @@ transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ Override to define how you want to move an arbitrary batch to a device. -.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). .. testcode:: @@ -303,6 +303,8 @@ on_before_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch before it is transferred to the device. +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). + .. 
testcode:: class MNISTDataModule(LightningDataModule): @@ -314,7 +316,7 @@ on_after_batch_transfer ^^^^^^^^^^^^^^^^^^^^^^^ Override to alter or apply batch augmentations to your batch after it is transferred to the device. -.. note:: This hook only runs on single gpu training (no data-parallel) or with sharded plugin. +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). .. testcode:: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a9df7b0dc8065..2bb0a904a6eeb 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -538,7 +538,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = any other device than the one passed in as argument (unless you know what you are doing). Note: - This hook only runs on single GPU training (no data-parallel) or with sharded plugin. + This hook only runs on single GPU training and DDP (no data-parallel). If you need multi-GPU support for your custom batch objects, you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and @@ -573,6 +573,9 @@ def on_before_batch_transfer(self, batch): """ Override to alter or apply batch augmentations to your batch before it is transferred to the device. + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Args: batch: A batch of data that needs to be altered or augmented. @@ -596,7 +599,7 @@ def on_after_batch_transfer(self, batch): Override to alter or apply batch augmentations to your batch after it is transferred to the device. Note: - This hook only runs on single gpu training (no data-parallel) or with sharded plugin. + This hook only runs on single GPU training and DDP (no data-parallel). Args: batch: A batch of data that needs to be altered or augmented. 
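The ``lightning.py`` changes that follow route the example inputs used for ONNX export and TorchScript tracing through this same hook pipeline. A minimal usage sketch of the traced-export path (the model, its shapes, and the output file name are assumptions; it presumes ``to_torchscript`` as it stands in this series):

    import torch
    from pytorch_lightning import LightningModule


    class ExportableModel(LightningModule):  # illustrative model, not part of the PR
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            # picked up automatically when to_torchscript(method='trace') gets no explicit inputs
            self.example_input_array = torch.randn(5, 32)

        def forward(self, x):
            return self.layer(x)


    model = ExportableModel()
    # the example input is passed through on_before_batch_transfer ->
    # transfer_batch_to_device -> on_after_batch_transfer before tracing
    scripted = model.to_torchscript(method="trace")
    scripted.save("exportable_model.pt")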
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f09808bba698a..0144b480a478e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1703,7 +1703,7 @@ def to_onnx( ) input_sample = self.example_input_array - input_sample = self.transfer_batch_to_device(input_sample) + input_sample = self._prepare_batch_for_transfer(input_sample) if "example_outputs" not in kwargs: self.eval() @@ -1766,29 +1766,6 @@ def to_torchscript( """ mode = self.training -<<<<<<< HEAD - if method == 'script': - torchscript_module = torch.jit.script(self.eval(), **kwargs) - elif method == 'trace': - # if no example inputs are provided, try to see if model has example_input_array set - if example_inputs is None: - if self.example_input_array is None: - raise ValueError( - 'Choosing method=`trace` requires either `example_inputs`' - ' or `model.example_input_array` to be defined' - ) - example_inputs = self.example_input_array - - # automatically send example inputs to the right device and use trace - example_inputs = self.prepare_batch_for_transfer(example_inputs) - torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) - else: - raise ValueError( - "The 'method' parameter only supports 'script' or 'trace'," - f" but value given was: {method}" - ) - -======= with torch.no_grad(): if method == 'script': torchscript_module = torch.jit.script(self.eval(), **kwargs) @@ -1802,7 +1779,16 @@ def to_torchscript( else: raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" f"{method}") ->>>>>>> make it private + example_inputs = self.example_input_array + + # automatically send example inputs to the right device and use trace + example_inputs = self._prepare_batch_for_transfer(example_inputs) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) + else: + raise ValueError( + f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}" + ) + self.train(mode) if file_path is not None: diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index cb223ab0a79d9..39f4f6c663637 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -456,14 +456,9 @@ def transfer_batch_to_device(self, batch, device): batch.targets = batch.targets.to(device) return batch -<<<<<<< HEAD dm = CurrentTestDM() model = BoringModel() -======= - model = BoringModel() - dm = CurrentTestDM() ->>>>>>> pep and BoringModel batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer(gpus=1) From 29859f8657e094c8ef169d63378167522352cbf9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 27 Jan 2021 00:54:40 +0530 Subject: [PATCH 22/32] rebase --- pytorch_lightning/core/lightning.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 0144b480a478e..e27bf2d83bd4a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1766,19 +1766,16 @@ def to_torchscript( """ mode = self.training - with torch.no_grad(): - if method == 'script': - torchscript_module = torch.jit.script(self.eval(), **kwargs) - elif method == 'trace': - # if no example inputs are provided, try to see if model has example_input_array set - if example_inputs is None: - example_inputs = self.example_input_array - # automatically 
send example inputs to the right device and use trace - example_inputs = self._prepare_batch_for_transfer(example_inputs) - torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) - else: - raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" - f"{method}") + if method == 'script': + torchscript_module = torch.jit.script(self.eval(), **kwargs) + elif method == 'trace': + # if no example inputs are provided, try to see if model has example_input_array set + if example_inputs is None: + if self.example_input_array is None: + raise ValueError( + 'Choosing method=`trace` requires either `example_inputs`' + ' or `model.example_input_array` to be defined.' + ) example_inputs = self.example_input_array # automatically send example inputs to the right device and use trace From 1bb334977928452f7e959c1b05540c3e7c87ac45 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 29 Jan 2021 01:20:56 +0530 Subject: [PATCH 23/32] suggestions --- docs/source/extensions/datamodules.rst | 35 ++++++++++++++++++-------- pytorch_lightning/core/hooks.py | 8 +++--- pytorch_lightning/core/lightning.py | 6 ++--- tests/core/test_datamodules.py | 9 +++++-- tests/models/test_hooks.py | 6 +++-- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a9e53b5fce34e..170430c48a358 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -164,6 +164,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) + .. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``. @@ -179,6 +180,7 @@ To define a DataModule define 5 methods: - val_dataloader(s) - test_dataloader(s) + prepare_data ^^^^^^^^^^^^ Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed @@ -196,7 +198,9 @@ settings. MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) -.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`). + +.. warning:: ``prepare_data`` is called from a single GPU. Do not use it to assign state (``self.x = y``). + setup ^^^^^ @@ -269,7 +273,6 @@ Use this method to generate the val dataloader. Usually you just wrap the datas def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=64) -.. _datamodule-test-dataloader-label: test_dataloader ^^^^^^^^^^^^^^^ @@ -284,27 +287,28 @@ Use this method to generate the test dataloader. Usually you just wrap the datas def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=64) + transfer_batch_to_device ^^^^^^^^^^^^^^^^^^^^^^^^ Override to define how you want to move an arbitrary batch to a device. -.. note:: This hook only runs on single GPU training and DDP (no data-parallel). - .. testcode:: class MNISTDataModule(LightningDataModule): def transfer_batch_to_device(self, batch, device): x = batch['x'] x = CustomDataWrapper(x) - x.to(device) + batch['x'] = x.to(device) return batch -on_before_batch_transfer -^^^^^^^^^^^^^^^^^^^^^^^^ -Override to alter or apply batch augmentations to your batch before it is transferred to the device. .. 
note:: This hook only runs on single GPU training and DDP (no data-parallel). + +on_before_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch before it is transferred to the device. + .. testcode:: class MNISTDataModule(LightningDataModule): @@ -312,12 +316,14 @@ Override to alter or apply batch augmentations to your batch before it is transf batch['x'] = transforms(batch['x']) return batch -on_after_batch_transfer -^^^^^^^^^^^^^^^^^^^^^^^ -Override to alter or apply batch augmentations to your batch after it is transferred to the device. .. note:: This hook only runs on single GPU training and DDP (no data-parallel). + +on_after_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch after it is transferred to the device. + .. testcode:: class MNISTDataModule(LightningDataModule): @@ -326,6 +332,13 @@ Override to alter or apply batch augmentations to your batch after it is transfe return batch +.. note:: + This hook only runs on single GPU training and DDP (no data-parallel). This hook + will also be called when using CPU device, so adding augmentations here or in + ``on_before_batch_transfer`` means the same thing. + + + .. note:: To decouple your data from transforms you can parametrize them via ``__init__``. .. code-block:: python diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 2bb0a904a6eeb..acb9ffbe713e8 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -589,8 +589,8 @@ def on_before_batch_transfer(self, batch): return batch See Also: - - :func:`~pytorch_lightning.core.on_after_batch_transfer` - - :func:`~pytorch_lightning.core.transfer_batch_to_device` + - :func:`~pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer` + - :func:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` """ return batch @@ -614,8 +614,8 @@ def on_after_batch_transfer(self, batch): return batch See Also: - - :func:`~pytorch_lightning.core.on_before_batch_transfer` - - :func:`~pytorch_lightning.core.transfer_batch_to_device` + - :func:`~pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer` + - :func:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` """ return batch diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e27bf2d83bd4a..36cc32fd39d38 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -48,12 +48,12 @@ class LightningModule( ABC, - CheckpointHooks, - DataHooks, DeviceDtypeModuleMixin, GradInformation, - ModelHooks, ModelIO, + ModelHooks, + DataHooks, + CheckpointHooks, Module, ): # Below is for property support of JIT in PyTorch 1.7 diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 39f4f6c663637..3db3b7fe6b1a9 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -422,8 +422,13 @@ def test_step_end(self, outputs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +<<<<<<< HEAD @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) def test_dm_apply_batch_transfer_handler(get_module_mock): +======= +def test_dm_prepare_batch_for_transfer(tmpdir): + expected_device = torch.device('cuda', 0) +>>>>>>> suggestions class CustomBatch: @@ -444,6 +449,7 @@ def on_before_batch_transfer(self, batch): return batch def 
on_after_batch_transfer(self, batch): + assert batch.samples.device == batch.targets.device == expected_device self.on_after_batch_transfer_hook_rank = self.rank self.rank += 1 batch.targets *= 2 @@ -472,8 +478,7 @@ def transfer_batch_to_device(self, batch, device): model.on_after_batch_transfer = dm.on_after_batch_transfer trainer.accelerator_backend = GPUAccelerator(trainer) - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected_device = torch.device('cuda', 0) + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device) assert dm.on_before_batch_transfer_hook_rank == 0 assert dm.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1abd96be72d2e..bdbe5596306ca 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -146,6 +146,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) def test_apply_batch_transfer_handler(model_getter_mock): + expected_device = torch.device('cuda', 0) class CustomBatch: def __init__(self, data): @@ -165,6 +166,7 @@ def on_before_batch_transfer(self, batch): return batch def on_after_batch_transfer(self, batch): + assert batch.samples.device == batch.targets.device == expected_device self.on_after_batch_transfer_hook_rank = self.rank self.rank += 1 batch.targets *= 2 @@ -182,9 +184,9 @@ def transfer_batch_to_device(self, batch, device): trainer = Trainer(gpus=1) # running .fit() would require us to implement custom data loaders, we mock the model reference instead + model_getter_mock.return_value = model - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected_device = torch.device('cuda', 0) + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device) assert model.on_before_batch_transfer_hook_rank == 0 assert model.transfer_batch_to_device_hook_rank == 1 From be081ed3c0b36139e7aa9fa3d6f50f590d4c1ab6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 29 Jan 2021 01:28:46 +0530 Subject: [PATCH 24/32] docs --- docs/source/extensions/datamodules.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 170430c48a358..01d69da76804e 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -274,6 +274,8 @@ Use this method to generate the val dataloader. Usually you just wrap the datas return DataLoader(self.mnist_val, batch_size=64) +.. _datamodule-test-dataloader-label: + test_dataloader ^^^^^^^^^^^^^^^ Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``. 
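
Taken together, the three hooks documented in the patches above compose into one pipeline: ``on_before_batch_transfer`` sees the CPU batch first, ``transfer_batch_to_device`` moves it, and ``on_after_batch_transfer`` receives the device-resident result (the same order ``_apply_batch_transfer_handler`` calls them in). A minimal sketch of that order, assuming the hook signatures introduced in this series; the ``ToyDataModule`` name, tensor shapes, and transforms are invented for illustration:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningDataModule


    class ToyDataModule(LightningDataModule):

        def setup(self, stage=None):
            self.train_set = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=8)

        def on_before_batch_transfer(self, batch):
            # called first, while the batch is still on the CPU
            x, y = batch
            return x / 255.0, y

        def transfer_batch_to_device(self, batch, device):
            # called second; only responsible for moving tensors to `device`
            return tuple(t.to(device) for t in batch)

        def on_after_batch_transfer(self, batch):
            # called last; tensors already live on the target device
            x, y = batch
            return x * 2, y


    # manual walk-through of the hook order (also runs on CPU-only machines)
    dm = ToyDataModule()
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    batch = dm.on_before_batch_transfer(batch)
    batch = dm.transfer_batch_to_device(batch, torch.device("cpu"))
    batch = dm.on_after_batch_transfer(batch)
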
From f1f78811a43ad6a333b59f9a135cb703a904876b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Jan 2021 00:19:48 +0530 Subject: [PATCH 25/32] suggestions --- pytorch_lightning/core/hooks.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index acb9ffbe713e8..d9697c495fcbe 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -563,8 +563,8 @@ def transfer_batch_to_device(self, batch, device): return batch See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` + - :meth:`~pytorch_lightning.utilities.apply_func.move_data_to_device` + - :meth:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ device = device or self.device return move_data_to_device(batch, device) @@ -589,8 +589,8 @@ def on_before_batch_transfer(self, batch): return batch See Also: - - :func:`~pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer` - - :func:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` + - :meth:`~pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer` + - :meth:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` """ return batch @@ -614,8 +614,8 @@ def on_after_batch_transfer(self, batch): return batch See Also: - - :func:`~pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer` - - :func:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` + - :meth:`~pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer` + - :meth:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` """ return batch @@ -631,7 +631,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: Args: checkpoint: Loaded checkpoint - Example:: def on_load_checkpoint(self, checkpoint): From e104afbbb8d452524cab30a76ec828f7492c616a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Jan 2021 00:26:13 +0530 Subject: [PATCH 26/32] try fix docs --- pytorch_lightning/core/hooks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index d9697c495fcbe..6ce3877b3ccd7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -589,8 +589,8 @@ def on_before_batch_transfer(self, batch): return batch See Also: - - :meth:`~pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer` - - :meth:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` + - :meth:`on_after_batch_transfer` + - :meth:`transfer_batch_to_device` """ return batch From 4100e3026fcb9a9c4918115c6c42d089a28b3aa0 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Jan 2021 00:33:23 +0530 Subject: [PATCH 27/32] docs --- pytorch_lightning/core/hooks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 6ce3877b3ccd7..ecde729ce7db0 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -563,8 +563,8 @@ def transfer_batch_to_device(self, batch, device): return batch See Also: - - :meth:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - - :meth:`~pytorch_lightning.utilities.apply_func.apply_to_collection` + - :meth:`move_data_to_device` + - :meth:`apply_to_collection` """ device = 
device or self.device return move_data_to_device(batch, device) @@ -614,8 +614,8 @@ def on_after_batch_transfer(self, batch): return batch See Also: - - :meth:`~pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer` - - :meth:`~pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device` + - :meth:`on_before_batch_transfer` + - :meth:`transfer_batch_to_device` """ return batch From 2e524239e0861ab669607c5924d69204c51f2cf9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 6 Feb 2021 00:15:13 +0530 Subject: [PATCH 28/32] update name --- pytorch_lightning/core/decorators.py | 3 +-- pytorch_lightning/core/lightning.py | 6 +++--- pytorch_lightning/core/memory.py | 2 +- pytorch_lightning/loggers/tensorboard.py | 2 +- tests/core/test_datamodules.py | 4 ---- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 36f3cacc162b0..895f004baa950 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -57,8 +57,7 @@ def auto_transfer_args(self, *args, **kwargs): if not isinstance(self, LightningModule): return fn(self, *args, **kwargs) - args = self.transfer_batch_to_device(args) - kwargs = self.transfer_batch_to_device(kwargs) + args, kwargs = self.transfer_batch_to_device((args, kwargs)) return fn(self, *args, **kwargs) return auto_transfer_args diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 36cc32fd39d38..8dcfd5dc82852 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -179,7 +179,7 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None - def _prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None): + def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None): batch = self.on_before_batch_transfer(batch) batch = self.transfer_batch_to_device(batch, device) batch = self.on_after_batch_transfer(batch) @@ -1703,7 +1703,7 @@ def to_onnx( ) input_sample = self.example_input_array - input_sample = self._prepare_batch_for_transfer(input_sample) + input_sample = self._apply_batch_transfer_handler(input_sample) if "example_outputs" not in kwargs: self.eval() @@ -1779,7 +1779,7 @@ def to_torchscript( example_inputs = self.example_input_array # automatically send example inputs to the right device and use trace - example_inputs = self._prepare_batch_for_transfer(example_inputs) + example_inputs = self._apply_batch_transfer_handler(example_inputs) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: raise ValueError( diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index a663fe43b976f..afb64535d1470 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -246,7 +246,7 @@ def _forward_example_input(self) -> None: trainer = self._model.trainer input_ = model.example_input_array - input_ = model._prepare_batch_for_transfer(input_, model.device) + input_ = model._apply_batch_transfer_handler(input_, model.device) if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 69a0d2c28e707..0485868fa2ef1 100644 --- 
a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -211,7 +211,7 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - input_array = model._prepare_batch_for_transfer(input_array) + input_array = model._apply_batch_transfer_handler(input_array) self.experiment.add_graph(model, input_array) else: rank_zero_warn( diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3db3b7fe6b1a9..fcb2168d18866 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -422,13 +422,9 @@ def test_step_end(self, outputs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -<<<<<<< HEAD @mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) def test_dm_apply_batch_transfer_handler(get_module_mock): -======= -def test_dm_prepare_batch_for_transfer(tmpdir): expected_device = torch.device('cuda', 0) ->>>>>>> suggestions class CustomBatch: From 7c7231a7f28e887a9e7875ddd17e7bf34f42465e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 7 Feb 2021 14:59:14 +0530 Subject: [PATCH 29/32] yapf --- tests/models/test_hooks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index bdbe5596306ca..b258fb17e9015 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -149,6 +149,7 @@ def test_apply_batch_transfer_handler(model_getter_mock): expected_device = torch.device('cuda', 0) class CustomBatch: + def __init__(self, data): self.samples = data[0] self.targets = data[1] From 7a41afc40febbc4dec7f6f645563e19ad763ff79 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 7 Feb 2021 15:02:10 +0530 Subject: [PATCH 30/32] docs --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 61c5c39361e37..813d5ee978821 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -374,7 +374,7 @@ def package_list_from_file(file): import torch from torch import nn import pytorch_lightning as pl -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities import ( _NATIVE_AMP_AVAILABLE, _APEX_AVAILABLE, From 404445d57b4d912dbe09717a0ce7a4b656cd4071 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Feb 2021 02:35:37 +0530 Subject: [PATCH 31/32] rebase --- pytorch_lightning/accelerators/accelerator.py | 10 +++++----- pytorch_lightning/core/datamodule.py | 7 +------ tests/core/test_datamodules.py | 1 - 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4f4b10e2730c1..88430f037d0e4 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -106,7 +106,7 @@ def teardown(self): """ pass - def batch_to_device(self, batch: Any, device: torch.device) -> Any: + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just having all tensors on the correct device. 
@@ -115,8 +115,10 @@ def batch_to_device(self, batch: Any, device: torch.device) -> Any: device: The target device """ model = self.lightning_module + if model is not None: - return model.transfer_batch_to_device(batch, device) + return model._apply_batch_transfer_handler(batch, device) + return move_data_to_device(batch, device) def on_train_start(self): @@ -136,9 +138,7 @@ def training_step(self, args): :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. """ - batch = self.to_device(args[0]) - - args[0] = batch + args[0] = self.to_device(args[0]) with self.precision_plugin.train_step_context(): with self.training_type_plugin.train_step_context(): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 8f1273bf4d6b3..9ce05ce0b36cf 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -19,7 +19,6 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union -import torch from torch.utils.data import DataLoader, Dataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks @@ -95,11 +94,7 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn -class LightningDataModule( - CheckpointHooks, - DataHooks, - metaclass=_DataModuleWrapper -): +class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index fcb2168d18866..e2f3b559073d2 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -473,7 +473,6 @@ def transfer_batch_to_device(self, batch, device): model.transfer_batch_to_device = dm.transfer_batch_to_device model.on_after_batch_transfer = dm.on_after_batch_transfer - trainer.accelerator_backend = GPUAccelerator(trainer) batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device) assert dm.on_before_batch_transfer_hook_rank == 0 From 8ddc2d7a87f73a48a4b5a102092ba9acfe943196 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Feb 2021 23:10:18 +0530 Subject: [PATCH 32/32] yapf --- pytorch_lightning/core/lightning.py | 4 +--- pytorch_lightning/loggers/test_tube.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8dcfd5dc82852..b71ea0c7d7fb8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1782,9 +1782,7 @@ def to_torchscript( example_inputs = self._apply_batch_transfer_handler(example_inputs) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: - raise ValueError( - f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}" - ) + raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}") self.train(mode) diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 19484d96fcef2..5734f0fd8aebc 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -155,9 +155,7 @@ def log_graph(self, model: LightningModule, input_array=None): input_array = model.example_input_array if input_array is not None: - self.experiment.add_graph( - 
model, model._apply_batch_transfer_handler(input_array) - ) + self.experiment.add_graph(model, model._apply_batch_transfer_handler(input_array)) else: rank_zero_warn( 'Could not log computational graph since neither the'
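
With the rename in place, ``to_onnx``, ``to_torchscript``, and the graph loggers all funnel ``example_input_array`` through ``_apply_batch_transfer_handler``, so any ``on_before_batch_transfer`` / ``on_after_batch_transfer`` overrides also apply to the example input before tracing or graph logging. A rough usage sketch under that assumption; ``TracedModel`` and the tensor shapes are invented for illustration:

.. code-block:: python

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule


    class TracedModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)
            # used by to_torchscript(method='trace') when no example_inputs are given
            self.example_input_array = torch.randn(4, 32)

        def forward(self, x):
            return self.layer(x)

        def on_before_batch_transfer(self, batch):
            # runs on the example input before it is moved to the device and traced
            return batch.float()


    scripted = TracedModel().to_torchscript(method='trace')
    assert isinstance(scripted, torch.jit.ScriptModule)
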