
Commit e78bf20

rohitgr7 and awaelchli authored
Raise an error if batch transfer hooks are overridden with IPUAccelerator (#13961)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 332182d commit e78bf20

7 files changed: +53 -51 lines changed


docs/source-pytorch/accelerators/gpu_intermediate.rst

Lines changed: 3 additions & 2 deletions
@@ -47,8 +47,9 @@ after which the root node will aggregate the results.
     :doc:`Manual Optimization <../model/manual_optimization>` with DP. Use DDP which is more stable and at least 3x faster.
 
 .. warning:: DP only supports scattering and gathering primitive collections of tensors like lists, dicts, etc.
-    Therefore the :meth:`~pytorch_lightning.core.hooks.ModelHooks.transfer_batch_to_device` hook does not apply in
-    this mode and if you have overridden it, it will not be called.
+    Therefore the hooks :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_batch_transfer`,
+    :meth:`~pytorch_lightning.core.hooks.ModelHooks.transfer_batch_to_device` and :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_batch_transfer`
+    do not apply in this mode and if you have overridden any of them, an exception will be raised.
 
 .. testcode::
     :skipif: torch.cuda.device_count() < 2
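To make the updated warning concrete, here is a minimal sketch, not part of the commit, of the kind of override that now fails fast under DP. It assumes a machine with at least two CUDA devices; the class name `DPModel` and the trainer options are illustrative.

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class DPModel(BoringModel):
    # Overriding any of the three batch-transfer hooks triggers the check; this one is illustrative.
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return batch.to(device)


# Assumes at least two CUDA devices are available.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="dp", max_steps=1)
try:
    trainer.fit(DPModel())
except MisconfigurationException as err:
    print(err)  # Overriding `transfer_batch_to_device` is not supported in DP mode.

The exception is raised during configuration validation at the start of `trainer.fit`, before any data is moved, so the misconfiguration surfaces immediately rather than being silently ignored.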

docs/source-pytorch/accelerators/ipu_basic.rst

Lines changed: 3 additions & 0 deletions
@@ -67,3 +67,6 @@ Please see the `MNIST example <https://github.com/Lightning-AI/lightning/blob/ma
 * Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code.
 * Clipping gradients is not supported.
 * It is not possible to use :class:`torch.utils.data.BatchSampler` in your dataloaders if you are using multiple IPUs.
+* IPUs handle the data transfer to the device on the host, hence the hooks :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_batch_transfer`,
+  :meth:`~pytorch_lightning.core.hooks.ModelHooks.transfer_batch_to_device` and :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_batch_transfer`
+  do not apply here and if you have overridden any of them, an exception will be raised.
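A similar sketch of the failure mode this note describes, again not part of the commit; it assumes an IPU machine with poptorch installed so that `accelerator="ipu"` is selectable, and the class name `IPUModel` is made up for illustration.

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUModel(BoringModel):
    # On IPUs the host performs the device transfer, so this hook is rejected up front.
    def on_before_batch_transfer(self, batch, dataloader_idx):
        return batch + 1


# Assumes an IPU-enabled environment (poptorch installed).
trainer = pl.Trainer(accelerator="ipu", devices=1, max_steps=1)
try:
    trainer.fit(IPUModel())
except MisconfigurationException as err:
    print(err)  # Overriding `on_before_batch_transfer` is not supported with IPUs.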

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
 
 
--
+- Raised a `MisconfigurationException` if batch transfer hooks are overridden with `IPUAccelerator` ([#13961](https://github.com/Lightning-AI/lightning/pull/13961))
 
 
 ### Deprecated

src/pytorch_lightning/core/hooks.py

Lines changed: 9 additions & 0 deletions
@@ -665,6 +665,9 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
             MisconfigurationException:
                 If using data-parallel, ``Trainer(strategy='dp')``.
 
+            MisconfigurationException:
+                If using IPUs, ``Trainer(accelerator='ipu')``.
+
         See Also:
             - :meth:`move_data_to_device`
             - :meth:`apply_to_collection`
@@ -700,6 +703,9 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
             MisconfigurationException:
                 If using data-parallel, ``Trainer(strategy='dp')``.
 
+            MisconfigurationException:
+                If using IPUs, ``Trainer(accelerator='ipu')``.
+
         See Also:
             - :meth:`on_after_batch_transfer`
             - :meth:`transfer_batch_to_device`
@@ -735,6 +741,9 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
             MisconfigurationException:
                 If using data-parallel, ``Trainer(strategy='dp')``.
 
+            MisconfigurationException:
+                If using IPUs, ``Trainer(accelerator='ipu')``.
+
         See Also:
             - :meth:`on_before_batch_transfer`
             - :meth:`transfer_batch_to_device`

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 9 additions & 3 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytorch_lightning as pl
+from pytorch_lightning.accelerators.ipu import IPUAccelerator
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.strategies import DataParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -45,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")
 
-    __verify_dp_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer, model)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
@@ -148,17 +149,22 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
 
 
-def __verify_dp_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
-    # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
     batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     datahook_selector = trainer._data_connector._datahook_selector
     for hook in batch_transfer_hooks:
+        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (
             is_overridden(hook, datahook_selector.model) or is_overridden(hook, datahook_selector.datamodule)
         ):
             raise MisconfigurationException(f"Overriding `{hook}` is not supported in DP mode.")
 
+        if isinstance(trainer.accelerator, IPUAccelerator) and (
+            is_overridden(hook, datahook_selector.model) or is_overridden(hook, datahook_selector.datamodule)
+        ):
+            raise MisconfigurationException(f"Overriding `{hook}` is not supported with IPUs.")
+
 
 def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     if model.automatic_optimization:
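The renamed `__verify_batch_transfer_support` relies on `is_overridden` to decide whether a hook was redefined on the model or the attached datamodule. A small self-contained sketch of that detection; the class names `PlainModel` and `PatchedModel` are made up for illustration.

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.model_helpers import is_overridden


class PlainModel(BoringModel):
    pass


class PatchedModel(BoringModel):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        return batch * 2


# `is_overridden` compares the bound method against the default implementation on the
# parent class, which is how the validator spots a redefined batch-transfer hook.
print(is_overridden("on_after_batch_transfer", PlainModel()))    # False
print(is_overridden("on_after_batch_transfer", PatchedModel()))  # True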

tests/tests_pytorch/strategies/test_dp.py

Lines changed: 0 additions & 45 deletions
@@ -11,20 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
 
-import pytest
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
-from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -154,47 +150,6 @@ def _assert_extra_outputs(self, outputs):
         assert out.dtype is torch.float
 
 
-@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2)
-@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True)
-def test_dp_raise_exception_with_batch_transfer_hooks(mock_is_available, mock_device_count, tmpdir):
-    """Test that an exception is raised when overriding batch_transfer_hooks in DP model."""
-
-    class CustomModel(BoringModel):
-        def transfer_batch_to_device(self, batch, device, dataloader_idx):
-            batch = batch.to(device)
-            return batch
-
-    trainer_options = dict(default_root_dir=tmpdir, max_steps=7, accelerator="gpu", devices=[0, 1], strategy="dp")
-
-    trainer = Trainer(**trainer_options)
-    model = CustomModel()
-
-    with pytest.raises(MisconfigurationException, match=r"Overriding `transfer_batch_to_device` is not .* in DP"):
-        trainer.fit(model)
-
-    class CustomModel(BoringModel):
-        def on_before_batch_transfer(self, batch, dataloader_idx):
-            batch += 1
-            return batch
-
-    trainer = Trainer(**trainer_options)
-    model = CustomModel()
-
-    with pytest.raises(MisconfigurationException, match=r"Overriding `on_before_batch_transfer` is not .* in DP"):
-        trainer.fit(model)
-
-    class CustomModel(BoringModel):
-        def on_after_batch_transfer(self, batch, dataloader_idx):
-            batch += 1
-            return batch
-
-    trainer = Trainer(**trainer_options)
-    model = CustomModel()
-
-    with pytest.raises(MisconfigurationException, match=r"Overriding `on_after_batch_transfer` is not .* in DP"):
-        trainer.fit(model)
-
-
 @RunIf(min_cuda_gpus=2)
 def test_dp_training_step_dict(tmpdir):
     """This test verifies that dp properly reduces dictionaries."""

tests/tests_pytorch/trainer/test_config_validator.py

Lines changed: 28 additions & 0 deletions
@@ -14,9 +14,11 @@
 import pytest
 import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
+from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
 
@@ -192,3 +194,29 @@ def setup(self, pl_module, trainer):
 
     with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("trainer_kwargs", [{"accelerator": "ipu"}, {"accelerator": "gpu", "strategy": "dp"}])
+@pytest.mark.parametrize("hook", ["on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer"])
+def test_raise_exception_with_batch_transfer_hooks(monkeypatch, hook, trainer_kwargs, tmpdir):
+    """Test that an exception is raised when overriding batch_transfer_hooks."""
+    if trainer_kwargs.get("accelerator") == "gpu":
+        match_pattern = rf"Overriding `{hook}` is not .* in DP mode."
+        monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True)
+        monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 2)
+    elif trainer_kwargs.get("accelerator") == "ipu":
+        match_pattern = rf"Overriding `{hook}` is not .* with IPUs"
+        monkeypatch.setattr(pl.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
+        monkeypatch.setattr(pl.strategies.ipu, "_IPU_AVAILABLE", lambda: True)
+
+    def custom_method(self, batch, *_, **__):
+        batch = batch + 1
+        return batch
+
+    trainer = Trainer(default_root_dir=tmpdir, **trainer_kwargs)
+
+    model = BoringModel()
+    setattr(model, hook, custom_method)
+
+    with pytest.raises(MisconfigurationException, match=match_pattern):
+        trainer.fit(model)
