diff --git a/CHANGELOG.md b/CHANGELOG.md
index 08405cb89b392..4f721b263668f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 
+- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index dd45e592bdd7e..af9ce25f902b3 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -1,10 +1,11 @@
 import logging
 import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if TYPE_CHECKING:
@@ -48,3 +49,11 @@ def set_nvidia_flags() -> None:
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")
+
+    def to_device(self, batch: Any) -> Any:
+        # no need to transfer batch to device in DP mode
+        # TODO: Add support to allow batch transfer to device in Lightning for DP mode.
+        if not isinstance(self.training_type_plugin, DataParallelPlugin):
+            batch = super().to_device(batch)
+
+        return batch
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 9826f9d44ac2c..1399d1b3c66ba 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
-            If you need multi-GPU support for your custom batch objects, you need to define your custom
-            :class:`~torch.nn.parallel.DistributedDataParallel` or
-            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
-            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
                     batch = super().transfer_batch_to_device(data, device)
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`move_data_to_device`
             - :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.
 
-        .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
+        .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = transforms(batch['x'])
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_after_batch_transfer`
             - :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in the near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
                 batch['x'] = gpu_transforms(batch['x'])
                 return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_before_batch_transfer`
             - :meth:`transfer_batch_to_device`
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index fbe1cecdd837e..b3fc0b4eb7b29 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(model, train_dataloader, val_dataloaders)
         self.attach_datamodule(model, datamodule)
+        self._validate_data_hooks(model)
 
     def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )
 
+    def _validate_data_hooks(self, model):
+        # Raise MisconfigurationException since these hooks are not supported in DP mode
+        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
+        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+        for hook in batch_transfer_hooks:
+            if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
+                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
+
     def attach_dataloaders(
         self,
         model,
@@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
         if datamodule:
 
             # Override loader hooks
-            if is_overridden('train_dataloader', datamodule):
-                model.train_dataloader = datamodule.train_dataloader
-            if is_overridden('val_dataloader', datamodule):
-                model.val_dataloader = datamodule.val_dataloader
-            if is_overridden('test_dataloader', datamodule):
-                model.test_dataloader = datamodule.test_dataloader
-            if is_overridden('predict_dataloader', datamodule):
-                model.predict_dataloader = datamodule.predict_dataloader
+            dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
+            for method in dl_methods:
+                if is_overridden(method, datamodule):
+                    setattr(model, method, getattr(datamodule, method))
 
             # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
-            if is_overridden('on_before_batch_transfer', datamodule):
-                model.on_before_batch_transfer = datamodule.on_before_batch_transfer
-            if is_overridden('transfer_batch_to_device', datamodule):
-                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
-            if is_overridden('on_after_batch_transfer', datamodule):
-                model.on_after_batch_transfer = datamodule.on_after_batch_transfer
+            batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+            for hook in batch_transfer_hooks:
+                if is_overridden(hook, datamodule):
+                    setattr(model, hook, getattr(datamodule, hook))
 
             self.trainer.datamodule = datamodule
             datamodule.trainer = self.trainer
diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py
index 6b84e1a70ae58..ab46aba3119fb 100644
--- a/tests/accelerators/test_dp.py
+++ b/tests/accelerators/test_dp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
@@ -18,8 +19,10 @@
 import pytorch_lightning as pl
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
+from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.core import memory
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
@@ -132,6 +135,56 @@ def training_epoch_end(self, outputs):
         assert outputs[0]["reduce_float"].item() == 0.5  # mean([0., 1.]) = 0.5
 
 
+def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
+    """
+    Test that an exception is raised when overriding batch_transfer_hooks in DP mode.
+ """ + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + + class CustomModel(BoringModel): + + def transfer_batch_to_device(self, batch, device): + batch = batch.to(device) + return batch + + trainer_options = dict( + default_root_dir=tmpdir, + max_steps=7, + gpus=[0, 1], + accelerator='dp', + ) + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'): + trainer.fit(model) + + @RunIf(min_gpus=2) def test_dp_training_step_dict(tmpdir): """ This test verifies that dp properly reduces dictionaries """