
Commit 2d9e00f

Profile batch transfer and gradient clipping hooks (#14069)

1 parent e53c4e8

7 files changed (+57, -40 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))


+- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))
+
+
 -

src/pytorch_lightning/core/module.py

Lines changed: 17 additions & 10 deletions
@@ -37,7 +37,6 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.loggers import Logger, LoggerCollection
-from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -291,16 +290,24 @@ def _apply_batch_transfer_handler(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
-        datahook_selector = (
-            _DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector
-        )

-        hook = datahook_selector.get_hook("on_before_batch_transfer")
-        batch = hook(batch, dataloader_idx)
-        hook = datahook_selector.get_hook("transfer_batch_to_device")
-        batch = hook(batch, device, dataloader_idx)
-        hook = datahook_selector.get_hook("on_after_batch_transfer")
-        batch = hook(batch, dataloader_idx)
+        def call_hook(hook_name, *args):
+            if self._trainer:
+                datahook_selector = self._trainer._data_connector._datahook_selector
+                obj = datahook_selector.get_instance(hook_name)
+                trainer_method = (
+                    self._trainer._call_lightning_module_hook
+                    if isinstance(obj, self.__class__)
+                    else self._trainer._call_lightning_datamodule_hook
+                )
+                return trainer_method(hook_name, *args)
+            else:
+                hook = getattr(self, hook_name)
+                return hook(*args)
+
+        batch = call_hook("on_before_batch_transfer", batch, dataloader_idx)
+        batch = call_hook("transfer_batch_to_device", batch, device, dataloader_idx)
+        batch = call_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch

     def print(self, *args, **kwargs) -> None:
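In short: when a `Trainer` is attached, the hook owner is resolved via `_DataHookSelector.get_instance` and the call is routed through `trainer._call_lightning_module_hook` or `trainer._call_lightning_datamodule_hook`, which wrap it in the profiler; without a trainer, the module's own method is called directly. The sketch below (not from this diff) shows a typical user override of one of these hooks; `PairBatch` and `CustomTransferModel` are hypothetical names used only for illustration.

    # Illustrative only: a custom batch type moved to the device via
    # `transfer_batch_to_device`, one of the hooks now dispatched through the
    # trainer call wrappers (and therefore profiled) when a Trainer is attached.
    from dataclasses import dataclass

    import torch
    from pytorch_lightning import LightningModule


    @dataclass
    class PairBatch:  # hypothetical container, not part of Lightning
        inputs: torch.Tensor
        targets: torch.Tensor


    class CustomTransferModel(LightningModule):
        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            if isinstance(batch, PairBatch):
                return PairBatch(batch.inputs.to(device), batch.targets.to(device))
            # Fall back to Lightning's default handling for everything else.
            return super().transfer_batch_to_device(batch, device, dataloader_idx)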

src/pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 3 additions & 1 deletion
@@ -182,7 +182,9 @@ def _clip_gradients(
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
            # the configuration validator disallows clipping on manual
            return
-        model.configure_gradient_clipping(
+
+        model.trainer._call_lightning_module_hook(
+            "configure_gradient_clipping",
             optimizer,
             optimizer_idx,
             gradient_clip_val=clip_val,
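Because the precision plugin now invokes the hook through `trainer._call_lightning_module_hook`, a user override of `configure_gradient_clipping` is also profiled. A hedged sketch of such an override (not from this diff; `ClippedModel` is a hypothetical name, and the signature follows the documented `LightningModule` API):

    # Illustrative only: override `configure_gradient_clipping` to clip a single
    # optimizer; `self.clip_gradients` is the other hook now profiled by this commit.
    from pytorch_lightning import LightningModule


    class ClippedModel(LightningModule):
        def configure_gradient_clipping(
            self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
        ):
            if optimizer_idx == 0:
                self.clip_gradients(
                    optimizer,
                    gradient_clip_val=gradient_clip_val,
                    gradient_clip_algorithm=gradient_clip_algorithm,
                )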

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 11 additions & 11 deletions
@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Callable, Collection, List, Optional, Tuple, Union
+from typing import Any, Collection, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -527,16 +527,16 @@ def is_module(self) -> bool:

 @dataclass
 class _DataHookSelector:
-    """Stores the info about the shared DataHooks within LightningModule and LightningDataModule.
+    """Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``.

-    The hook source can be
+    The hook source can be:

-    1. a method from the :class:`~pytorch_lightning.core.module.LightningModule`,
-    2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    1. the :class:`~pytorch_lightning.core.module.LightningModule`,
+    2. the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,

     Arguments:
-        model: A LightningModule
-        datamodule: A LightningDataModule
+        model: A ``LightningModule``
+        datamodule: A ``LightningDataModule``
     """

     model: "pl.LightningModule"
@@ -545,27 +545,27 @@ class _DataHookSelector:
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )

-    def get_hook(self, hook_name: str) -> Callable:
+    def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         if hook_name not in self._valid_hooks:
             raise ValueError(
                 f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`."
                 f" Valid hooks are {self._valid_hooks}."
             )

         if self.datamodule is None:
-            return getattr(self.model, hook_name)
+            return self.model

         if is_overridden(hook_name, self.datamodule):
             if is_overridden(hook_name, self.model):
                 warning_cache.warn(
                     f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`."
                     " It will use the implementation from `LightningDataModule` instance."
                 )
-            return getattr(self.datamodule, hook_name)
+            return self.datamodule

         if is_overridden(hook_name, self.model):
             warning_cache.warn(
                 f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
                 " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
             )
-            return getattr(self.model, hook_name)
+            return self.model
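The selector thus returns the owning instance instead of a bound method, so callers can dispatch through the profiled trainer wrappers. A small sketch of the resolution order encoded above (illustrative, not a test from this commit; it uses the protected `_DataHookSelector` API the same way the tests below do):

    # When both the module and the datamodule override a shared hook, the
    # datamodule instance wins and a warning is emitted.
    from pytorch_lightning import LightningDataModule, LightningModule
    from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector


    class MyModule(LightningModule):
        def on_before_batch_transfer(self, batch, dataloader_idx):
            return batch


    class MyDataModule(LightningDataModule):
        def on_before_batch_transfer(self, batch, dataloader_idx):
            return batch


    selector = _DataHookSelector(MyModule(), MyDataModule())
    assert selector.get_instance("on_before_batch_transfer") is selector.datamodule
    # Non-shared hooks such as "setup" raise a ValueError, as exercised by the tests.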

src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,8 @@ class _LogOptions(TypedDict):
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "lr_scheduler_step": None,
+        "configure_gradient_clipping": None,
+        "clip_gradients": None,
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
@@ -98,6 +100,9 @@ class _LogOptions(TypedDict):
         "on_epoch_end": _LogOptions(
             allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
         ),
+        "on_before_batch_transfer": None,
+        "transfer_batch_to_device": None,
+        "on_after_batch_transfer": None,
         "on_batch_start": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 13 additions & 13 deletions
@@ -471,59 +471,59 @@ def test_no_datamodule_no_overridden(self, hook_name):
         model, _, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=None)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(model, hook_name)
+        assert instance is model

     def test_with_datamodule_no_overridden(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(model, hook_name)
+        assert instance is model

     def test_override_model_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(model, hook_name)
+        assert instance is model

     def test_override_datamodule_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(dm, hook_name, self.overridden_func)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(dm, hook_name)
+        assert instance is dm

     def test_override_both_model_and_datamodule(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
         setattr(dm, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(dm, hook_name)
+        assert instance is dm

     def test_with_datamodule_override_model(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

-        assert hook == getattr(model, hook_name)
+        assert instance is model


 def test_invalid_hook_passed_in_datahook_selector():
     dh_selector = _DataHookSelector(BoringModel(), None)
     with pytest.raises(ValueError, match="is not a shared hook"):
-        dh_selector.get_hook("setup")
+        dh_selector.get_instance("setup")


 def test_eval_distributed_sampler_warning(tmpdir):

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 5 additions & 5 deletions
@@ -187,11 +187,6 @@ def __init__(self, not_supported):
             {
                 "log",
                 "log_dict",
-                # the following are problematic as they do have `self._current_fx_name` defined some times but
-                # not others depending on where they were called. So we cannot reliably `self.log` in them
-                "on_before_batch_transfer",
-                "transfer_batch_to_device",
-                "on_after_batch_transfer",
             }
         )
         # remove `nn.Module` hooks
@@ -227,6 +222,9 @@ def test_fx_validator_integration(tmpdir):
         "on_pretrain_routine_end": "You can't",
         "train_dataloader": "You can't",
         "val_dataloader": "You can't",
+        "on_before_batch_transfer": "You can't",
+        "transfer_batch_to_device": "You can't",
+        "on_after_batch_transfer": "You can't",
         "on_validation_end": "You can't",
         "on_train_end": "You can't",
         "on_fit_end": "You can't",
@@ -238,6 +236,8 @@ def test_fx_validator_integration(tmpdir):
         "on_validation_model_eval": "You can't",
         "on_validation_model_train": "You can't",
         "lr_scheduler_step": "You can't",
+        "configure_gradient_clipping": "You can't",
+        "clip_gradients": "You can't",
         "on_save_checkpoint": "You can't",
         "on_load_checkpoint": "You can't",
         "on_exception": "You can't",
