
Commit f3f282b

awaelchli and rohitgr7 committed
Profile batch transfer and gradient clipping hooks (#14069)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent af9841c · commit f3f282b

7 files changed (+63 −52 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 9 additions & 12 deletions
@@ -6,23 +6,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.7.2] - 2022-08-16
 
-### Fixed
-
-- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))
-
+### Added
 
-- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
-- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
+- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))
 
 ### Changed
 
 - Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967))
+- Raised a `MisconfigurationException` if batch transfer hooks are overriden with `IPUAccelerator` ([13961](https://github.com/Lightning-AI/lightning/pull/13961))
 
+### Fixed
 
-- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))
-
-
+- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))
 - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
+- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
+- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
+- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))
+- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))
 
 
 ## [1.7.1] - 2022-08-09
@@ -39,9 +39,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
 
 
-- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))
-
-
 ## [1.7.0] - 2022-08-02
 
 ### Added
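The user-visible effect of the #14069 entry is that these five hooks are now wrapped by the trainer's profiling machinery and therefore show up in profiler reports. A rough usage sketch, not part of this commit — the `BoringModel` import path and the exact names printed in the report may vary between versions:

# Sketch: run a tiny fit with the built-in "simple" profiler attached.
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel  # import path used by recent versions

trainer = pl.Trainer(max_epochs=1, limit_train_batches=4, limit_val_batches=2, profiler="simple")
trainer.fit(BoringModel())
# The summary printed at the end of fit() is expected to include rows for the newly
# profiled hooks, e.g. `[LightningModule]BoringModel.transfer_batch_to_device` and
# `[LightningModule]BoringModel.configure_gradient_clipping`.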

src/pytorch_lightning/core/module.py

Lines changed: 17 additions & 10 deletions
@@ -38,7 +38,6 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.loggers import Logger, LoggerCollection
-from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -293,16 +292,24 @@ def _apply_batch_transfer_handler(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
-        datahook_selector = (
-            _DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector
-        )
 
-        hook = datahook_selector.get_hook("on_before_batch_transfer")
-        batch = hook(batch, dataloader_idx)
-        hook = datahook_selector.get_hook("transfer_batch_to_device")
-        batch = hook(batch, device, dataloader_idx)
-        hook = datahook_selector.get_hook("on_after_batch_transfer")
-        batch = hook(batch, dataloader_idx)
+        def call_hook(hook_name, *args):
+            if self._trainer:
+                datahook_selector = self._trainer._data_connector._datahook_selector
+                obj = datahook_selector.get_instance(hook_name)
+                trainer_method = (
+                    self._trainer._call_lightning_module_hook
+                    if isinstance(obj, self.__class__)
+                    else self._trainer._call_lightning_datamodule_hook
+                )
+                return trainer_method(hook_name, *args)
+            else:
+                hook = getattr(self, hook_name)
+                return hook(*args)
+
+        batch = call_hook("on_before_batch_transfer", batch, dataloader_idx)
+        batch = call_hook("transfer_batch_to_device", batch, device, dataloader_idx)
+        batch = call_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch
 
     def print(self, *args, **kwargs) -> None:
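With this refactor, `_apply_batch_transfer_handler` no longer calls the selected hook directly: when a trainer is attached, the call is routed through `trainer._call_lightning_module_hook` or `trainer._call_lightning_datamodule_hook` (depending on which object `get_instance` reports as the hook's owner), which is what places these hooks inside the trainer's profiling scope; without a trainer it falls back to a plain method call. For reference, a minimal sketch of the user-facing hooks being dispatched, assuming dict-shaped batches — the bodies are illustrative and not part of this commit:

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # Runs before the batch is moved to the target device (typically still on CPU).
        batch["x"] = batch["x"].float()
        return batch

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # Performs the actual device move for a custom batch structure.
        return {key: value.to(device) for key, value in batch.items()}

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Runs after the move, e.g. for device-side normalization or augmentation.
        batch["x"] = (batch["x"] - batch["x"].mean()) / (batch["x"].std() + 1e-6)
        return batch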

src/pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 3 additions & 1 deletion
@@ -178,7 +178,9 @@ def _clip_gradients(
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
             # the configuration validator disallows clipping on manual
             return
-        model.configure_gradient_clipping(
+
+        model.trainer._call_lightning_module_hook(
+            "configure_gradient_clipping",
             optimizer,
             optimizer_idx,
             gradient_clip_val=clip_val,
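Gradient clipping gets the same treatment: instead of invoking `model.configure_gradient_clipping(...)` directly, the precision plugin now goes through `trainer._call_lightning_module_hook("configure_gradient_clipping", ...)`, so the hook is profiled like any other `LightningModule` hook. For reference, a sketch of a user override that this call path reaches — the halved clip value is purely illustrative, and the signature shown is the one documented for Lightning 1.7:

from pytorch_lightning import LightningModule


class MyClippingModel(LightningModule):
    def configure_gradient_clipping(
        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
    ):
        # Illustrative: clip more aggressively than the Trainer(gradient_clip_val=...) setting.
        if gradient_clip_val is not None:
            gradient_clip_val = gradient_clip_val / 2
        # Delegate to the built-in clipping implementation (also profiled after this change).
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )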

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 11 additions & 11 deletions
@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Callable, Collection, List, Optional, Tuple, Union
+from typing import Any, Collection, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -527,16 +527,16 @@ def is_module(self) -> bool:
 
 @dataclass
 class _DataHookSelector:
-    """Stores the info about the shared DataHooks within LightningModule and LightningDataModule.
+    """Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``.
 
-    The hook source can be
+    The hook source can be:
 
-    1. a method from the :class:`~pytorch_lightning.core.module.LightningModule`,
-    2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    1. the :class:`~pytorch_lightning.core.module.LightningModule`,
+    2. the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
 
     Arguments:
-        model: A LightningModule
-        datamodule: A LightningDataModule
+        model: A ``LightningModule``
+        datamodule: A ``LightningDataModule``
     """
 
     model: "pl.LightningModule"
@@ -545,27 +545,27 @@ class _DataHookSelector:
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
 
-    def get_hook(self, hook_name: str) -> Callable:
+    def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         if hook_name not in self._valid_hooks:
             raise ValueError(
                 f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`."
                 f" Valid hooks are {self._valid_hooks}."
             )
 
         if self.datamodule is None:
-            return getattr(self.model, hook_name)
+            return self.model
 
         if is_overridden(hook_name, self.datamodule):
             if is_overridden(hook_name, self.model):
                 warning_cache.warn(
                     f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`."
                     " It will use the implementation from `LightningDataModule` instance."
                 )
-            return getattr(self.datamodule, hook_name)
+            return self.datamodule
 
         if is_overridden(hook_name, self.model):
             warning_cache.warn(
                 f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
                 " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
             )
-        return getattr(self.model, hook_name)
+        return self.model
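`get_hook` used to return a bound method; `get_instance` now returns the owning object instead, so the caller in `LightningModule._apply_batch_transfer_handler` can pick the matching trainer wrapper (`_call_lightning_module_hook` vs. `_call_lightning_datamodule_hook`). A sketch of the selection rule, exercising the private selector directly the way the updated tests do — illustrative only, not a public API:

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector


class DM(LightningDataModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):  # overridden in the DataModule
        return batch


class Model(LightningModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):  # also overridden in the LightningModule
        return batch


dm, model = DM(), Model()
selector = _DataHookSelector(model, dm)
# Both objects override the hook, so the DataModule wins (a UserWarning is emitted) and the
# caller routes the invocation through `_call_lightning_datamodule_hook`.
assert selector.get_instance("on_before_batch_transfer") is dm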

src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,8 @@ class _LogOptions(TypedDict):
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "lr_scheduler_step": None,
+        "configure_gradient_clipping": None,
+        "clip_gradients": None,
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
@@ -98,6 +100,9 @@ class _LogOptions(TypedDict):
         "on_epoch_end": _LogOptions(
             allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
         ),
+        "on_before_batch_transfer": None,
+        "transfer_batch_to_device": None,
+        "on_after_batch_transfer": None,
         "on_batch_start": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
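In `_FxValidator.functions`, a `None` value marks a hook the validator knows about but in which `self.log` is not supported; the newly wrapped hooks need such entries so that logging inside them fails with an explicit message (the "You can't" expectations added to the tests below). A hedged sketch of the user-facing effect — the exception text is paraphrased:

from pytorch_lightning import LightningModule


class BadLoggingModel(LightningModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # When run under a Trainer, this is expected to raise a MisconfigurationException
        # along the lines of "You can't `self.log()` inside `on_before_batch_transfer`".
        self.log("batch_seen", 1.0)
        return batch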

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 13 additions & 13 deletions
@@ -471,59 +471,59 @@ def test_no_datamodule_no_overridden(self, hook_name):
         model, _, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=None)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_with_datamodule_no_overridden(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_override_model_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_override_datamodule_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(dm, hook_name, self.overridden_func)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(dm, hook_name)
+        assert instance is dm
 
     def test_override_both_model_and_datamodule(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
         setattr(dm, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(dm, hook_name)
+        assert instance is dm
 
     def test_with_datamodule_override_model(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
 
 def test_invalid_hook_passed_in_datahook_selector():
     dh_selector = _DataHookSelector(BoringModel(), None)
     with pytest.raises(ValueError, match="is not a shared hook"):
-        dh_selector.get_hook("setup")
+        dh_selector.get_instance("setup")
 
 
 def test_eval_distributed_sampler_warning(tmpdir):

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 5 additions & 5 deletions
@@ -187,11 +187,6 @@ def __init__(self, not_supported):
             {
                 "log",
                 "log_dict",
-                # the following are problematic as they do have `self._current_fx_name` defined some times but
-                # not others depending on where they were called. So we cannot reliably `self.log` in them
-                "on_before_batch_transfer",
-                "transfer_batch_to_device",
-                "on_after_batch_transfer",
             }
         )
         # remove `nn.Module` hooks
@@ -227,6 +222,9 @@ def test_fx_validator_integration(tmpdir):
         "on_pretrain_routine_end": "You can't",
         "train_dataloader": "You can't",
         "val_dataloader": "You can't",
+        "on_before_batch_transfer": "You can't",
+        "transfer_batch_to_device": "You can't",
+        "on_after_batch_transfer": "You can't",
         "on_validation_end": "You can't",
         "on_train_end": "You can't",
         "on_fit_end": "You can't",
@@ -238,6 +236,8 @@ def test_fx_validator_integration(tmpdir):
         "on_validation_model_eval": "You can't",
         "on_validation_model_train": "You can't",
         "lr_scheduler_step": "You can't",
+        "configure_gradient_clipping": "You can't",
+        "clip_gradients": "You can't",
         "on_save_checkpoint": "You can't",
         "on_load_checkpoint": "You can't",
         "on_exception": "You can't",
