3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))


- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


## [1.2.0] - 2021-02-18

### Added
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -1,10 +1,11 @@
import logging
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
@@ -48,3 +49,11 @@ def set_nvidia_flags() -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, batch: Any) -> Any:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
batch = super().to_device(batch)

return batch
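(Not part of the PR: a minimal standalone sketch of why skipping the transfer is safe under DP. `torch.nn.DataParallel` scatters the input batch onto the replica devices inside its own `forward`, so moving the batch up front would be redundant. `TinyModel` is a made-up module used only for this illustration.)

import torch
from torch import nn

class TinyModel(nn.Module):
    """Toy module used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        # Under nn.DataParallel, `x` arrives already scattered to this replica's device.
        return self.layer(x)

if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(TinyModel().cuda(), device_ids=[0, 1])
    batch = torch.randn(64, 32)  # deliberately left on the CPU
    # DataParallel chunks `batch` along dim 0 and copies each chunk to its GPU,
    # which is why `GPUAccelerator.to_device` can become a no-op in DP mode.
    output = model(batch)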
21 changes: 16 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
batch = super().transfer_batch_to_device(data, device)
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
"""
Override to alter or apply batch augmentations to your batch before it is transferred to the device.

.. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
.. warning:: ``dataloader_idx`` is always 0 and will be updated to support the true index in the future.

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch['x'] = gpu_transforms(batch['x'])
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
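(Not part of the PR: a minimal sketch of how the three hooks documented above chain together on a single-GPU or DDP run, where they remain supported. The dict-shaped batch is an assumption for illustration, and the usual `training_step`/`configure_optimizers` are omitted for brevity.)

from pytorch_lightning import LightningModule

class HookDemoModule(LightningModule):

    # 1) Runs first, on the CPU batch, before any device transfer.
    def on_before_batch_transfer(self, batch, dataloader_idx):
        batch['x'] = batch['x'].float()
        return batch

    # 2) Moves the (possibly custom) batch object to the target device.
    def transfer_batch_to_device(self, batch, device):
        return {k: v.to(device) for k, v in batch.items()}

    # 3) Runs last, on the target device, e.g. for GPU-side augmentations.
    def on_after_batch_transfer(self, batch, dataloader_idx):
        batch['x'] = batch['x'] * 2
        return batch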
31 changes: 17 additions & 14 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule)
self._validate_data_hooks(model)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def _validate_data_hooks(self, model):
# Raise MisconfigurationException since these hooks are not supported in DP mode
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
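(Not part of the PR: the gist of what `is_overridden` checks, written as a small standalone helper rather than Lightning's actual implementation. A hook counts as overridden when the subclass provides a different function object than the base class.)

class Base:
    def transfer_batch_to_device(self, batch, device):
        return batch

class Child(Base):
    def transfer_batch_to_device(self, batch, device):
        return batch.to(device)

def overrides(obj, method_name, base_cls=Base):
    # Overridden when the object's class defines a different function than the base class does.
    return getattr(type(obj), method_name) is not getattr(base_cls, method_name)

assert overrides(Child(), 'transfer_batch_to_device')
assert not overrides(Base(), 'transfer_batch_to_device')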

def attach_dataloaders(
self,
model,
@@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
if datamodule:

# Override loader hooks
if is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader
if is_overridden('predict_dataloader', datamodule):
model.predict_dataloader = datamodule.predict_dataloader
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
if is_overridden('on_before_batch_transfer', datamodule):
model.on_before_batch_transfer = datamodule.on_before_batch_transfer
if is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device
if is_overridden('on_after_batch_transfer', datamodule):
model.on_after_batch_transfer = datamodule.on_after_batch_transfer
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer
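(Not part of the PR: a toy illustration of the forwarding pattern above. `getattr(datamodule, name)` returns a method bound to the datamodule, so after `setattr` the model attribute still runs the datamodule's code with the datamodule as `self`.)

class DataModule:
    def train_dataloader(self):
        return "datamodule loader"

class Model:
    def train_dataloader(self):
        return "model loader"

dm, model = DataModule(), Model()
# Forward the datamodule's implementation onto the model instance.
model.train_dataloader = dm.train_dataloader  # bound method, `self` stays `dm`
assert model.train_dataloader() == "datamodule loader"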
53 changes: 53 additions & 0 deletions tests/accelerators/test_dp.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
@@ -132,6 +135,56 @@ def training_epoch_end(self, outputs):
assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5


def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
"""
Test that an exception is raised when overriding batch_transfer_hooks in DP mode.
"""
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

class CustomModel(BoringModel):

def transfer_batch_to_device(self, batch, device):
batch = batch.to(device)
return batch

trainer_options = dict(
default_root_dir=tmpdir,
max_steps=7,
gpus=[0, 1],
accelerator='dp',
)

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_before_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_after_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'):
trainer.fit(model)


@RunIf(min_gpus=2)
def test_dp_training_step_dict(tmpdir):
""" This test verifies that dp properly reduces dictionaries """
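(Not part of the PR: the three near-identical checks in `test_dp_raise_exception_with_batch_transfer_hooks` could also be collapsed with `pytest.mark.parametrize`. A sketch, assuming the same imports and helpers as this test file; the test name is hypothetical.)

@pytest.mark.parametrize(
    "hook", ["on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer"]
)
def test_dp_batch_transfer_hook_rejection(tmpdir, monkeypatch, hook):
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

    class CustomModel(BoringModel):
        pass

    # Attach a dummy override for the hook under test; it is never called because
    # the Trainer rejects the configuration during setup.
    setattr(CustomModel, hook, lambda self, batch, *args: batch)

    trainer = Trainer(default_root_dir=tmpdir, max_steps=7, gpus=[0, 1], accelerator='dp')
    model = CustomModel()

    with pytest.raises(MisconfigurationException, match=f'Overriding `{hook}` is not .* in DP'):
        trainer.fit(model)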