diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2256dcefeac31..26c863be63d83 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -345,6 +345,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
 
 
+- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
+
+
+- Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))
+
+
+- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
+
 
 ## [1.3.8] - 2021-07-01
 
@@ -361,13 +369,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed SWA to also work with `IterableDataset` ([#8172](https://github.com/PyTorchLightning/pytorch-lightning/pull/8172))
 
-
-- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
-
-
-- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
-
-
 
 ## [1.3.7] - 2021-06-22
 
 ### Fixed
@@ -377,6 +378,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed setting a `DistributedSampler` when using a distributed plugin in a custom accelerator ([#7814](https://github.com/PyTorchLightning/pytorch-lightning/pull/7814))
 - Improved `PyTorchProfiler` chrome traces names ([#8009](https://github.com/PyTorchLightning/pytorch-lightning/pull/8009))
 - Fixed moving the best score to device in `EarlyStopping` callback for TPU devices ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
+- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))
 
 ## [1.3.6] - 2021-06-15
 
@@ -387,7 +389,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
 - Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
-- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))
 
 ## [1.3.5] - 2021-06-08
 
diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py
index bf2ddae2c0084..e01d45e4423ca 100644
--- a/benchmarks/test_basic_parity.py
+++ b/benchmarks/test_basic_parity.py
@@ -45,8 +45,8 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
 @pytest.mark.parametrize(
     'cls_model,max_diff_speed,max_diff_memory',
     [
-        (ParityModuleRNN, 0.05, 0.0),
-        (ParityModuleMNIST, 0.25, 0.0),  # todo: lower this thr
+        (ParityModuleRNN, 0.05, 0.001),
+        (ParityModuleMNIST, 0.25, 0.001),  # todo: lower this thr
     ]
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index abfb29c149bff..faae0b27b519e 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
@@ -173,8 +174,8 @@ def batch_to_device(
             dataloader_idx: The index of the dataloader to which the batch belongs.
         """
         model = self.lightning_module
-
-        if model is not None:
+        if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
+            # no need to transfer batch to device in DP mode
             return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
 
         return move_data_to_device(batch, device)
@@ -195,8 +196,6 @@ def training_step(
             - hiddens(:class:`~torch.Tensor`): Passed in if
               :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
 
@@ -215,8 +214,6 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
             - dataloader_idx (int): The index of the dataloader that produced this batch
               (only if multiple val dataloaders used)
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
 
@@ -232,8 +229,6 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
             - dataloader_idx (int): The index of the dataloader that produced this batch
               (only if multiple test dataloaders used).
""" - step_kwargs = self.to_device(step_kwargs) - with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) @@ -249,8 +244,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple predict dataloaders used). """ - step_kwargs = self.to_device(step_kwargs) - with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) @@ -371,13 +364,6 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.schedulers = schedulers - def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]: - """Pushes the batch to the root device""" - step_kwargs['batch'] = self.batch_to_device( - step_kwargs['batch'], self.root_device, dataloader_idx=step_kwargs.get('dataloader_idx', None) - ) - return step_kwargs - @property def amp_backend(self) -> Optional[LightningEnum]: if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 1c5ff56d805a6..3348727a36e61 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -13,13 +13,11 @@ # limitations under the License. import logging import os -from typing import Any, Dict, Union import torch import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.plugins import DataParallelPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException _log = logging.getLogger(__name__) @@ -51,11 +49,3 @@ def set_nvidia_flags(local_rank: int) -> None: all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - - def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]: - # no need to transfer batch to device in DP mode - # TODO: Add support to allow batch transfer to device in Lightning for DP mode. 
-        if not isinstance(self.training_type_plugin, DataParallelPlugin):
-            step_kwargs = super().to_device(step_kwargs)
-
-        return step_kwargs
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 64df877ce68ee..ef6c885930914 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -139,6 +139,10 @@ def advance(self, batch, batch_idx, dataloader_idx):
             if result:
                 self.batch_outputs[0].append(result.training_step_output)
 
+    def teardown(self) -> None:
+        # release memory
+        self._remaining_splits = None
+
     def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
         """Gets the number of active optimizers based on their frequency"""
         return len(self.get_active_optimizers(batch_idx))
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 7f8ef06d7687f..d63e94e062412 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -100,6 +100,9 @@ def advance(
         if batch is None:
             raise StopIteration
 
+        with self.trainer.profiler.profile("evaluation_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+
         # hook
         self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
 
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 29a76793b4648..6512a09be7a41 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -83,6 +83,9 @@ def advance(
         if batch is None:
             raise StopIteration
 
+        with self.trainer.profiler.profile("predict_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+
         with self.trainer.profiler.profile("predict_step"):
             self._predict_step(batch, batch_idx, dataloader_idx)
 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 3d6fbc8dcdc1c..40005df3a37c4 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -104,6 +104,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # ------------------------------------
         # TRAINING_STEP + TRAINING_STEP_END
         # ------------------------------------
+        with self.trainer.profiler.profile("training_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)
+
         with self.trainer.profiler.profile("run_training_batch"):
             batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
             self.batches_seen += 1
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 9a689fe9d725a..13c9b9f13ec23 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -287,13 +287,13 @@ def _train_batch(trainer, model, batches, current_epoch=0):
         out = []
         for i in range(batches):
             out.extend([
+                dict(name='on_before_batch_transfer', args=(ANY, 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `on_batch_{start,end}`
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
                 dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_train_batch_start', args=(ANY, i, 0)),
-                dict(name='on_before_batch_transfer', args=(ANY, None)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-                dict(name='on_after_batch_transfer', args=(ANY, None)),
                 dict(name='forward', args=(ANY, )),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -338,12 +338,12 @@ def _eval_batch(fn, trainer, model, batches, key):
         outputs = {key: ANY}
         for i in range(batches):
             out.extend([
+                dict(name='on_before_batch_transfer', args=(ANY, 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name=f'on_{fn}_batch_start', args=(ANY, i, 0)),
-                dict(name='on_before_batch_transfer', args=(ANY, None)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-                dict(name='on_after_batch_transfer', args=(ANY, None)),
                 dict(name='forward', args=(ANY, )),
                 dict(name=f'{fn}_step', args=(ANY, i)),
                 dict(name=f'{fn}_step_end', args=(outputs, )),
@@ -358,11 +358,11 @@ def _predict_batch(trainer, model, batches):
         for i in range(batches):
             out.extend([
                 # TODO: `{,Callback}.on_batch_{start,end}`
+                dict(name='on_before_batch_transfer', args=(ANY, 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_predict_batch_start', args=(ANY, i, 0)),
-                dict(name='on_before_batch_transfer', args=(ANY, None)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-                dict(name='on_after_batch_transfer', args=(ANY, None)),
                 dict(name='forward', args=(ANY, )),
                 dict(name='predict_step', args=(ANY, i)),
                 # TODO: `predict_step_end`
@@ -777,9 +777,9 @@ def call(hook, fn, *args, **kwargs):
     dm = HookedDataModule(called)
     trainer.fit(model, datamodule=dm)
     batch_transfer = [
-        dict(name='on_before_batch_transfer', args=(ANY, None)),
-        dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-        dict(name='on_after_batch_transfer', args=(ANY, None)),
+        dict(name='on_before_batch_transfer', args=(ANY, 0)),
+        dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+        dict(name='on_after_batch_transfer', args=(ANY, 0)),
     ]
     expected = [
         dict(name='prepare_data'),
diff --git a/tests/trainer/loops/test_all.py b/tests/trainer/loops/test_all.py
new file mode 100644
index 0000000000000..e0527c2905cb2
--- /dev/null
+++ b/tests/trainer/loops/test_all.py
@@ -0,0 +1,93 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Callback, Trainer
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
+class BatchHookObserverCallback(Callback):
+
+    def on_train_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_validation_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_test_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_predict_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+
+class BatchHookObserverModel(BoringModel):
+
+    def on_train_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_train_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_validation_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_validation_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_test_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_test_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_predict_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_predict_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+
+@RunIf(min_gpus=1)
+def test_callback_batch_on_device(tmpdir):
+    """ Test that the batch object sent to the on_*_batch_start/end hooks is on the right device."""
+
+    batch_callback = BatchHookObserverCallback()
+
+    model = BatchHookObserverModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        limit_predict_batches=1,
+        gpus=1,
+        callbacks=[batch_callback],
+    )
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 14f47a2558eff..762405c6710bc 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -1533,7 +1533,7 @@ def __init__(self):
 
     def assert_dataloader_idx_hook(self, dataloader_idx):
         if self.trainer.training:
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
         elif self.trainer.validating:
             assert dataloader_idx == (0 if self.val_call_count <= 5 else 1)
         elif self.trainer.testing:
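For context, a minimal sketch of what this change lets user code assume: the epoch loops now move the batch via `Accelerator.batch_to_device` before invoking the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks, so those hooks receive a device-resident batch (mirroring the new `test_callback_batch_on_device` test above). The `DeviceAwareModel` class below is illustrative only, not part of the patch, and assumes the batch is a single tensor:

import torch
from pytorch_lightning import LightningModule


class DeviceAwareModel(LightningModule):
    """Illustrative sketch: relies on the batch already being on ``self.device``."""

    def on_train_batch_start(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int) -> None:
        # With this patch, the trainer runs `batch_to_device` before this hook,
        # so the tensors in `batch` already live on the module's device.
        assert batch.device == self.device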